{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <center/>Experiencing Mixed Precision Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
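    "## Overview\n",
    "\n",
    "When a neural network is trained, data, weights and other parameters are usually computed and stored as single-precision floating-point numbers (float32). As networks become more complex, the amount of computation grows and memory usage becomes very large. Anyone who trains models regularly knows that running short of memory lowers training efficiency; put simply, training gets slower. Is there a way to speed up training without adding hardware? This tutorial introduces one such method: mixed precision training. In short, part of the computation is done with half-precision floating-point numbers (float16), which take half the storage of float32, so tensors kept in float16 need only half the memory. To preserve model accuracy, however, not every computation can be switched to half precision. To balance model accuracy and training efficiency, MindSpore provides automatic mixed precision training in the framework. In this tutorial we train a ResNet-50 network and compare MindSpore mixed precision training with single-precision training."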
   ]
  },
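  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the memory argument concrete, the following sketch uses NumPy (not MindSpore) to compare the storage size of the same tensor in float32 and float16. The tensor shape is a hypothetical activation shape chosen only for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical activation tensor: batch 32, 64 channels, 56x56 feature map\n",
    "shape = (32, 64, 56, 56)\n",
    "a32 = np.zeros(shape, dtype=np.float32)\n",
    "a16 = a32.astype(np.float16)\n",
    "\n",
    "print('float32 MB:', a32.nbytes / 2**20)  # 4 bytes per element\n",
    "print('float16 MB:', a16.nbytes / 2**20)  # 2 bytes per element\n",
    "print('ratio:', a32.nbytes / a16.nbytes)\n",
    "```\n",
    "\n",
    "Only tensors that are actually held in float16 shrink. Because the parameters themselves stay in FP32 (see the principles section), the overall saving in real training is smaller than a factor of two."
   ]
  },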
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The overall process is as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
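    "1. Introduce the principles of MindSpore mixed precision training.\n",
    "2. Prepare the dataset.\n",
    "3. Define the dynamic learning rate.\n",
    "4. Define the loss function.\n",
    "5. Define the ResNet-50 network.\n",
    "6. Define the `One_Step_Time` callback.\n",
    "7. Define the training network (the automatic mixed precision parameter `amp_level` is set here).\n",
    "8. Validate the model accuracy.\n",
    "9. Compare mixed precision training with single-precision training."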
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> You can find the complete runnable sample code here: <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet>."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Principles of MindSpore Mixed Precision Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://www.mindspore.cn/tutorial/zh-CN/master/_images/mix_precision.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
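    "1. Parameters are stored in FP32.\n",
    "2. During the forward pass, whenever an FP16 operator is encountered, its inputs and parameters are `cast` from FP32 to FP16 for the computation.\n",
    "3. The loss layer is computed in FP32.\n",
    "4. At the start of backpropagation, the loss is multiplied by the loss scale value so that small gradients do not underflow.\n",
    "5. FP16 parameters take part in the gradient computation, and the resulting gradients are `cast` back to FP32.\n",
    "6. The gradients are divided by the loss scale value to undo the scaling.\n",
    "7. The gradients are checked for overflow: if they overflowed, the update is skipped; otherwise the optimizer updates the original FP32 parameters."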
   ]
  },
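  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps 4 to 7 can be simulated with NumPy to see why loss scaling helps. This is a hedged illustration, not MindSpore's implementation: the fixed loss scale of 1024, the plain SGD update and the helper names `fp16_backward_with_loss_scale` and `sgd_step` are assumptions. The FP16 gradient is modeled as the true gradient multiplied by the loss scale (by the chain rule, scaling the loss scales every gradient by the same factor) and then rounded to float16.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "FP16_MAX = float(np.finfo(np.float16).max)  # largest finite float16 value: 65504\n",
    "\n",
    "def fp16_backward_with_loss_scale(true_grad, loss_scale):\n",
    "    # Step 4: scaling the loss scales every gradient by the same factor;\n",
    "    # the scaled gradient is produced (and rounded) in FP16\n",
    "    with np.errstate(over='ignore'):\n",
    "        grad_fp16 = (np.asarray(true_grad, dtype=np.float32) * loss_scale).astype(np.float16)\n",
    "    # Steps 5 and 6: cast back to FP32, then divide by the loss scale\n",
    "    grad_fp32 = grad_fp16.astype(np.float32) / loss_scale\n",
    "    # Step 7: an inf or nan means the scaled gradient overflowed in FP16\n",
    "    overflow = not bool(np.all(np.isfinite(grad_fp32)))\n",
    "    return grad_fp32, overflow\n",
    "\n",
    "def sgd_step(weight_fp32, true_grad, loss_scale, lr=0.1):\n",
    "    grad, overflow = fp16_backward_with_loss_scale(true_grad, loss_scale)\n",
    "    if overflow:\n",
    "        return weight_fp32, True  # skip the update\n",
    "    return weight_fp32 - lr * grad, False  # update the FP32 parameters\n",
    "\n",
    "tiny = [1e-8]  # below the smallest positive float16 (about 6e-8)\n",
    "print(fp16_backward_with_loss_scale(tiny, 1.0))     # underflows to 0 without scaling\n",
    "print(fp16_backward_with_loss_scale(tiny, 1024.0))  # approximately 1e-8 is recovered\n",
    "w, skipped = sgd_step(np.array([1.0], dtype=np.float32), [100.0], 1024.0)\n",
    "print(w, skipped)  # 100 * 1024 > 65504, so the gradient overflows and the step is skipped\n",
    "```\n",
    "\n",
    "In practice the loss scale can also be adjusted dynamically (lowered after an overflow, raised after a run of stable steps); this sketch keeps it fixed for simplicity."
   ]
  },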
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
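    "From the steps above (float16 is half precision, float32 is single precision), MindSpore uses `cast` to run the forward computation of the network in half precision, which saves memory and improves performance, while the `loss` value is computed and stored in single precision. The `weight` values are used in half precision for computation but stored in single precision. This improves training efficiency while largely preserving model accuracy, which is how mixed precision improves training performance."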
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
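    "Dataset download address (binary version): <https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>.\n",
    "\n",
    "After downloading, extract the archive to `datasets/cifar10` under the Jupyter working directory. The training and test batches are extracted into the same folder, so you need to split them into two folders, as shown below."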
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "cifar10\n",
    "├── test\n",
    "│   └── test_batch.bin\n",
    "└── train\n",
    "    ├── data_batch_1.bin\n",
    "    ├── data_batch_2.bin\n",
    "    ├── data_batch_3.bin\n",
    "    ├── data_batch_4.bin\n",
    "    └── data_batch_5.bin\n"
   ]
  },
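  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split can be scripted with the Python standard library. The sketch below assumes the binary archive was extracted to `./cifar-10-batches-bin`; that folder name and the helper `split_cifar10` are assumptions, so adjust the paths to your environment. It copies the batch files into the layout shown above.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "# Assumed locations; adjust to where you extracted the archive\n",
    "SRC_DIR = './cifar-10-batches-bin'\n",
    "DST_DIR = './datasets/cifar10'\n",
    "\n",
    "def split_cifar10(src_dir=SRC_DIR, dst_dir=DST_DIR):\n",
    "    train_dir = os.path.join(dst_dir, 'train')\n",
    "    test_dir = os.path.join(dst_dir, 'test')\n",
    "    os.makedirs(train_dir, exist_ok=True)\n",
    "    os.makedirs(test_dir, exist_ok=True)\n",
    "    # data_batch_1.bin ... data_batch_5.bin form the training set\n",
    "    for i in range(1, 6):\n",
    "        name = 'data_batch_%d.bin' % i\n",
    "        shutil.copy(os.path.join(src_dir, name), train_dir)\n",
    "    # test_batch.bin forms the test set\n",
    "    shutil.copy(os.path.join(src_dir, 'test_batch.bin'), test_dir)\n",
    "    return train_dir, test_dir\n",
    "```\n",
    "\n",
    "Calling `split_cifar10()` once creates the `train` and `test` folders; copying rather than moving leaves the extracted archive untouched."
   ]
  },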
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the files are placed correctly, you can find the `test` and `train` folders by appending `/tree/datasets/cifar10` to the URL of the Jupyter home page."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Augmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, visualize the raw CIFAR-10 dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the cifar dataset size is : 50000\n",
      "the tensor of image is: (32, 32, 3)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAActElEQVR4nO2da4xdV3XH/+vcx7z8jmPHOE4MITwCgUDdNFIookBRSpECUqHQCuUDxagCqUj0Q0SlQqV+gKqA+FBRmRKRVpSQ8hBpQS0ogiIqNWQIwQlxCI7jhImfiV8znrmvc1Y/3BvkhP1fM57HvSn7/5Msz+x19znr7nvWPXf2/661zN0hhPjNpxi1A0KI4aBgFyITFOxCZIKCXYhMULALkQkKdiEyob6SyWZ2E4DPAqgB+Cd3/0T0+ImJSd+wcRM7WnQiYuCyYXA0uFfBPD6zIjJlN5hTG2tym/H32jLwsVbU+LxuLzletefpnCJceu5jtMZWpOex8f7x+BGj18yrMvCEXSPLud6ASKleroxdLiwkx9udFp9US69ju9NFr1cmn8Cyg93MagD+AcDvA5gBcK+Z3eXuD7E5GzZuwh//6Z+lj1fjrtTqjeR4FbzIteC17HXTiwsA9SCQOr1OcvxJT/sHAJt3v4ja1k2MU9tsEJzrxjdQ29mTTyfHF37+EzpnfZMHUmNsktqiN6vmWPq5NafWB8fj69hpzXFbe5baUHXT4wW/3ooi8KPLr7my4sFelnyN5x7+WXL80UMH6Jxq40Ry/OGDv6RzVvIx/noAB939kLt3ANwB4OYVHE8IsYasJNh3ArjwbWRmMCaEeB6ykmBPfVD+tc8xZrbXzKbNbHphnn80FUKsLSsJ9hkAuy74/XIAR577IHff5+573H3PxCT/+08IsbasJNjvBXC1mb3QzJoA3g3grtVxSwix2ix7N97de2b2IQD/hb70dpu7p7cVfzUH6FVsmzyQoZDeIe90yU4rgMLTO+cA4JGkEcgnHbKh+th8egccAGznFdS2+ZIt1LZpfB21PXns1z5A/YqHjs0kx7ct8Odca/F1bExxWWOc7LgDgFv6mFWN72YXRHUBgFaLv5694LkVlj6fI7h2anznvApCplfya6fbDaRDslNvgZLXJUpUJP+tSGd3928D+PZKjiGEGA76Bp0QmaBgFyITFOxCZIKCXYhMULALkQkr2o2/eBwo01lZQVITld6CpDdUbS7HdBfOU1sRJOT4eDqJ47LtXEKrAplv5iSX0Lqd4NuGQQLQtnXpLy7ViAQFAIiy6Npc8krnVvVpNEjyEnn9AaAK5NIqSDIJTPCKyFrBc47kq6ri/ofXI/EDAHrk4o/uxAVNQuIviu7sQmSCgl2ITFCwC5EJCnYhMkHBLkQmDHU33t1R9tIJCB7WJkvvgHZawY51sAvuQTmlWpMndxS19A7zutoYnVOniT+Az/PyWO1ATagHO8LrWul5wR4y6uM89bgIdnfZawkAVTe9s251vivNyn4BwMICf63rUTE8croiql8Y3AIjVSAqS8XqFwJARX3hc3wZc3RnFyITFOxCZIKCXYhMULALkQkKdiEyQcEuRCYMXXrrMXkl6EHk7XZyvNvinUAaY7yeWb3OpTILbHXW0mj2NJ3TOcfr0xWRZBe9D7d5dxRrnUuOrx9LdxABgHqdy41W534UJfe/TerC1Qsuoc0FpQGfOPQwtV155S5qa5CEl24vqjPHk4aCfBb0Sj6v2wmOybrFRN2kltFpSnd2ITJBwS5EJijYhcgEBbsQmaBgFyITFOxCZMKKpDczOwxgFkAJoOfue6LHuzs6pKZZpx1ksNXS70lFLaidFpRc65Y8W6tW8WMWRVoi4cIV4OASoJX8vXaCPGcAqAouefWIdMgyBwGgdC57RqXrUAT1+hrpVTm/kJZRAaDR5lmAu+v8NfMg+65XT2tUlfP19aCN03LVMAuk
ZTNiY+MAaqxWYjBnNXT233P3p1bhOEKINUQf44XIhJUGuwP4jpn92Mz2roZDQoi1YaUf42909yNmtg3Ad83sYXf/wYUPGLwJ7AWAqSnehlgIsbas6M7u7kcG/58A8A0A1yces8/d97j7nrFx/v1sIcTasuxgN7MpM1v/zM8A3gLgwdVyTAixuqzkY/x2AN8YyAZ1AP/q7v8ZTSgAjNfS0kAx2aTzas30JwLW2gcArOTFC521k8IiEgmxFVFLo+B40VttVLzQAj2syVScoN1RLZAwIz0pKjjJ2hP1evySe3LmMLXt2rqB2irwa6fqpH2sB9l8YFloWKzgZNDiKch6K4mPZZhFl76+o9ZVyw52dz8E4NXLnS+EGC6S3oTIBAW7EJmgYBciExTsQmSCgl2ITBhqwcmJiTFc+4oXJm0WSDwswyeStZqBnGRB/7KoJxeToepBhlq9HmTm9bgf4+Prqa3bOk9tC8RWBlJklCHY6wWSUdCbjUlAR2f5+j4wzQt3Xn4JX4/L1vPL+NDBx5Pjm6a4lLd5y2ZqC5Ll0A2kyIVmIMttTZ/PxvjzOt9Ir+PjMzwnTXd2ITJBwS5EJijYhcgEBbsQmaBgFyIThrob315YwGP704lx3WATvEXq1s2ePUPnbNzAd1vHp3iqbRElrpCN6Xabt6GqnCfJuPHlf+3r3kRtjz7yM2o7c/J4cnyiyc/Vq/guMslnAQDU6/yY45NTyfGHHjlJ5zx9mrfKOrtwCbWNzZ2itoMz6d34dpe/Lq9+9TXUtnnzRm7byOs1vGA9n9fbnn5uC6TtGQDU16ev4YcefoTO0Z1diExQsAuRCQp2ITJBwS5EJijYhcgEBbsQmTBU6a2o1TC1KS1BtIIMg85sujWUBdLEBDkPAGy9ZCu1bdiwidpA6o8dOfIYnXL6DE/u8EjXChJ5CtJaCQDqzbTkVQV16zxoh+XdoB4btQCnT80lx48dPUbntNo8wWfm0BPUdu70WWqbXUi3lJqb5XLpj+69l9o2rk+vLwC84sUvobZx8roAwJHH0s9tbv4cnXPJNiLXzfE11J1diExQsAuRCQp2ITJBwS5EJijYhcgEBbsQmbCo9GZmtwF4G4AT7v7KwdgWAF8BsBvAYQDvcneuMQ2oyhJz59KZat2KZ5t1FtISW7MeZKh1W9R06vhRajt5lNuYGtZqp+UdALBA1ioKvvw1C+rrBZl54yS77dxZLsl0IwlzYozaFjpcfDt89ERy/NRpXiOtQVqDAYA3+H3p3HxUky/92jSDjL2Feb4ezUYQMsFrVga31VaRlnR9jGdnlo10yysPfFjKnf2LAG56ztitAO5296sB3D34XQjxPGbRYB/0W39uwvDNAG4f/Hw7gLevsl9CiFVmuX+zb3f3owAw+H/b6rkkhFgL1nyDzsz2mtm0mU0vtPjf0UKItWW5wX7czHYAwOD/9G4MAHff5+573H3PxDj/TrcQYm1ZbrDfBeCWwc+3APjm6rgjhFgrliK9fRnAGwBsNbMZAB8D8AkAd5rZ+wA8AeCdSzlZr6zw9Gz6o3y3xwsAdjrpgpNRp6aqx2WhqNXUQofLLkzVCBQj1MCflwWZbfOB/NOa51lejx1+NDn+4CO/oHOqoP/Tju1bqK1T8jU+/lQ6Y6vd5eeKin2ePsOfcydou+SePl+jwSXFrdv4FlSz4K/nZLPB/WC9wwBsXp8uVNkNWm9tnEzLcrUiaEVGLQPc/T3ExMufCiGed+gbdEJkgoJdiExQsAuRCQp2ITJBwS5EJgy14CQAONLyigUZYGPN9HtSUXDpql7nMkitxm31cS7jgEhDFvRza9a4PFULnnOUbTY+zm3HTqb7pVV1vlYerONTZ9OFIwGgIgU4AS6lBiolvTYAYKGVll8BwLj7mBxLZ4dt2rCezrn22mup7cgTvJdat8O/IVoEff2YFNwhPQ4BACWTZvlroju7EJmgYBciExTsQmSCgl2ITFCwC5EJCnYhMmG4vd4Kw/h4
Wgppd7l8VZHkn6jepHsgWwTpclEGmxOpqVdxua4VZIZVgR/n53kRy8dnjlDbmfPp4otXXHpZcC6eYffU2bSUBwCtNpeauh0ivdWWd8n12EUAoBHIiixjshbMefHuXdS2oeB+lEFGXy/I6pyaTPeBqzcCiZhkt1kgX+rOLkQmKNiFyAQFuxCZoGAXIhMU7EJkwlB34+v1GrZtTe88zp3nO7usnBzbkQSAIqrFFdR+Kyw4Zi1tawXtk1pdbiMb1gCAH983TW37H+L15HbtujI5vmXzJjrn8BO/pLbZ2XQtOYCvBwD0SO23KkjuqEeJTcG5yi5XQ84TxWDiHFc7OkErsrmg5dWZ07PUNj7JKyvPk+unHVxXZxfS69sN6gnqzi5EJijYhcgEBbsQmaBgFyITFOxCZIKCXYhMWEr7p9sAvA3ACXd/5WDs4wDeD+Dk4GEfdfdvL3asibEarnlRWgLqdYNabfV0zbVI+onq00WV0Io6P6aTxJUykNC6vXRiCgA8eew0td357/9Nbb9z7VXU9rs3/FZyvB34+OWTx6itEyQojRW8Ft6u7VuT43MkUQcAzgbJP2HbqEAuZf53Arnu0o3p1koAMPnindTWaqWfMwA0x7iP5+fS8mDkY0USg/YfeJjOWcqd/YsAbkqMf8bdrxv8WzTQhRCjZdFgd/cfADg1BF+EEGvISv5m/5CZ7Tez28xs86p5JIRYE5Yb7J8DcBWA6wAcBfAp9kAz22tm02Y2fW5ufpmnE0KslGUFu7sfd/fS3SsAnwdwffDYfe6+x933bFg3uVw/hRArZFnBbmY7Lvj1HQAeXB13hBBrxVKkty8DeAOArWY2A+BjAN5gZteh32vmMIAPLOVkhRVYN56+u1s6Ga7vZDNdi6sWyWvO38dabS7jtDs804i11qloKx5gLJDyZn55lNp2XXYptb35927g52uka/z98H9/Qud4i/95tfuKy6mtBi6Xvv9P/jA5HkloBw4+QW3/M83vJ6fPcDnvzTdelxz/6YHH+PGO8hZPN/72y6gtalFVD4obsjKFvEIhwNTGb32Hy6GLBru7vycx/IXF5gkhnl/oG3RCZIKCXYhMULALkQkKdiEyQcEuRCYMteBkr9fF6adOJG1TU1wyKFl2W9A+qSq5rRVkcoUtd4jNSXFFAOgs8HNNBHLMK6/aRm3o8eKcTx47nhz/xaHDdM61L08XqQSAex84SG1v+J1rqO2KnZckx7tBwcmX7d5CbZdv20Btd3zrHmp73fUvTY5ffUXaPwB4/IkZarvhVXytqh7PUgsuEYw1iIQctjdL63UWCHa6swuRCQp2ITJBwS5EJijYhcgEBbsQmaBgFyIThiq9VWWJ82fTFa68k87WAoCJibStqgLprce1jm6Py2HFBO/J1eum/TDjy9gI9JMdpO9d/5iBdHiOVwk7djxte8kVXMo7dYb3KNuyjkuiL7tyO7WdP5MupmmBXFor+VqtDwo2XrqZr2Ovnc7ou/6ll9E535/j63EwyMx7wfaN1DYf9GCbHE9fV/UgY7IgS2VEkgN0ZxciGxTsQmSCgl2ITFCwC5EJCnYhMmGou/Fl2cPZU+nd4jMeJIyQ3UrSAQdAmCNDW+cAcSJMSXbWreDLWKtzWxWpAsZ3ps88xXfjT59K74Kva/AFKVu8htuWKV7n78SRI9TGsjjKHk+EqQWJQafn+bydm4LL+PzZ5PBjh9PjALBlgvtxPiiHfm6cr9XpU09TW6OR9r/JEmQA1Ml2fKfDk3F0ZxciExTsQmSCgl2ITFCwC5EJCnYhMkHBLkQmLKX90y4A/wzgMgAVgH3u/lkz2wLgKwB2o98C6l3untZ9BlRlifmzacmjChI/5mfTtqDMHMogScarIFmgzuUOI3JScKqQqH1VWQXtggJZriTSy/kWf85VIIfNBq2yZp7k0lu7nfYjWnsr+EKebXE/jp9doLafP0pacwUv2jiRwgBg4ybenfxYcMyTx3irL5DXOihBR2vN
tRZ4K7Kl3Nl7AD7i7i8HcAOAD5rZNQBuBXC3u18N4O7B70KI5ymLBru7H3X3+wY/zwI4AGAngJsB3D542O0A3r5WTgohVs5F/c1uZrsBvAbAPQC2u/tRoP+GACCofSyEGDVLDnYzWwfgawA+7O7nLmLeXjObNrPpefJ3nBBi7VlSsJtZA/1A/5K7f30wfNzMdgzsOwAkuz+4+z533+PueybH+PfOhRBry6LBbmaGfj/2A+7+6QtMdwG4ZfDzLQC+ufruCSFWi6Vkvd0I4L0AHjCz+wdjHwXwCQB3mtn7ADwB4J2LHqmqUHbSrYtYRhkAOJXlov443FRE5+pe/J8aFkhhEZ0Ol0naXS5RtYL6eg3SKssjSTHIAtw0yWvyeSA1zS+kX+caa+UFoGCF1QBUQcuuyQaf1+2k5zWCC6QEX/uzT5+ktvNB7boiOCYjys6sSK05D57XosHu7j8Ej6o3LTZfCPH8QN+gEyITFOxCZIKCXYhMULALkQkKdiEyYagFJ82AJslsaneD6pHkLWmsEck4wftYIJV51J6ITQsy9kJ50LiPZZDS1ymjtkDMxo+3ocl9rNX561IzbpvakL60orZWUabiJc1AOjT+ZS22HL3gegteFjSDiOn0eDHKQHGkV4gHBTjN0usRqJe6swuRCwp2ITJBwS5EJijYhcgEBbsQmaBgFyIThiq9AUBBpJdGIDPUiJ5QD9+qeJZRtwzkpEC7YNlyVXnxGU1AXKiyWeNS01iTP/F2hxePZBRB4ctul2fmWVCYcZy8npGsVUVFNoMMsEimZAlxjaD3XSSXFoGEWasHGWfB9e3keUcZbCzTMkrA1J1diExQsAuRCQp2ITJBwS5EJijYhciEoe/Gs0SIItitZJuSnU5UaC5KdODvcZEbbDe+F+wUh3kwwXttVfJaeOb8oA2iJkStpoLcH9RqPMkkWquSJDxVzmvJ1aKt5MA0NRUkBpGJ3WDnH8H6stpvAFCQ5BQg3o3vkSWJ9IJexdeRoTu7EJmgYBciExTsQmSCgl2ITFCwC5EJCnYhMmFR6c3MdgH4ZwCXoZ9dss/dP2tmHwfwfgDP9MP5qLt/e7HjMTkhqj9mRO4IckUQdEgKpaaohlcN6YPWg9ZKUSm8KH+mCmxl8OQmSPPMIpC1Kqb9AGgEa1wtI2EkSjQK9bWoFVI30g7Tl3gRXQRBnbxelHQTLJb3gheU+NIg8iXAE16i1V2Kzt4D8BF3v8/M1gP4sZl9d2D7jLv//RKOIYQYMUvp9XYUwNHBz7NmdgDAzrV2TAixulzU3+xmthvAawDcMxj6kJntN7PbzGzzKvsmhFhFlhzsZrYOwNcAfNjdzwH4HICrAFyH/p3/U2TeXjObNrPp+fbFf8VPCLE6LCnYrV+F/2sAvuTuXwcAdz/u7qW7VwA+D+D61Fx33+fue9x9z+TY0L+KL4QYsGiwW7/+zRcAHHD3T18wvuOCh70DwIOr754QYrVYyq32RgDvBfCAmd0/GPsogPeY2XXoq2mHAXxg0SMZqDZQCzQqs7RsEclazSADKaoYF2a9MUkmKqwWZEl5lC0XZrbx8/W66Wy5KNOvGbTRsiD3qiyD9kTkfF3iX/94fK3GAlkrkvMqssZF0I8pTL4LFDsP9NJaMNGp/5H0dvFfkVnKbvwPkQ7RRTV1IcTzB32DTohMULALkQkKdiEyQcEuRCYo2IXIhKF+y8UA1InMUAZZSBWTIJy/V9WD3lBRW50qaA1VsGMGvkdZSFHLq6ilUT2QjTxYEzonyDiMkrUiWbHeJNlm4BJaacHaB0/LopZdRML0oKBndLKo5Vi0jha07Jqop9eqF1yLvXYg2xJ0ZxciExTsQmSCgl2ITFCwC5EJCnYhMkHBLkQmDD3BnBU+7AU9wJiAFUpQQS+sMpA0LEp5Iv3BGvUo64ofLspEi+oyluFByfGCA3a7fD08yK4aCwptMlWx3gyyEaNbT/C6tLt8PVgh0/Egi64IKpnWgiKQnW5w
DQdZjE4kzDApksl8UTFVbhJC/CahYBciExTsQmSCgl2ITFCwC5EJCnYhMmEEtZ2JNlBwuWOcSTxBtlkvkKciySsq5McywBrBe2YVFpyM+oYFxwwStpj8UwsraQY9ysKeaEHWHpnnvainX5ARF+hQURYj878MsgOjhLgo0y+iCq7HspP2MZTegpeFoTu7EJmgYBciExTsQmSCgl2ITFCwC5EJi+7Gm9k4gB8AGBs8/qvu/jEzeyGAOwBsAXAfgPe6e2eRY8GKRtIW1duqk/ckC5JnGs0mP95Y2gcAaM21qI3VOquCXfVaVAsvmBftgkd10Jpj6R1tD6rhefCe3y75GreDBBqWdxPkzqDs8XN1gtZQjSAhiokQ7aCGG23zBaBR435EyUbtQIVokPNFbb7Y9RHlcS3lzt4G8EZ3fzX67ZlvMrMbAHwSwGfc/WoApwG8bwnHEkKMiEWD3fvMDX5tDP45gDcC+Opg/HYAb18TD4UQq8JS+7PXBh1cTwD4LoBHAZxx/9Xn6BkAO9fGRSHEarCkYHf30t2vA3A5gOsBvDz1sNRcM9trZtNmNn2+FRWoEEKsJRe1G+/uZwB8H8ANADaZ2TMbfJcDOELm7HP3Pe6+Z2p8BN/OFUIAWEKwm9mlZrZp8PMEgDcDOADgewD+aPCwWwB8c62cFEKsnKXcancAuN3Maui/Odzp7v9hZg8BuMPM/hbATwB8YbEDWVHD2NSGpK0IpJUxIqNFuR3dsk1tdSJPAUBngauHTK6xoAZdEeRNFIG00g3qqkXSUEFqqzmpnwcAnaDHU6CuoXJuHGuk5c2wRRL48ZpBXbhO0CqLaVEe1IQLW5EFtkYQTc1AczSarMOPx2o5Riwa7O6+H8BrEuOH0P/7XQjx/wB9g06ITFCwC5EJCnYhMkHBLkQmKNiFyAQLa4yt9snMTgJ4fPDrVgBPDe3kHPnxbOTHs/n/5seV7n5pyjDUYH/Wic2m3X3PSE4uP+RHhn7oY7wQmaBgFyITRhns+0Z47guRH89Gfjyb3xg/RvY3uxBiuOhjvBCZMJJgN7ObzOznZnbQzG4dhQ8DPw6b2QNmdr+ZTQ/xvLeZ2Qkze/CCsS1m9l0z+8Xg/80j8uPjZvbkYE3uN7O3DsGPXWb2PTM7YGY/M7O/GIwPdU0CP4a6JmY2bmY/MrOfDvz4m8H4C83snsF6fMXMeFXVFO4+1H8AauiXtXoRgCaAnwK4Zth+DHw5DGDrCM77egCvBfDgBWN/B+DWwc+3AvjkiPz4OIC/HPJ67ADw2sHP6wE8AuCaYa9J4MdQ1wSAAVg3+LkB4B70C8bcCeDdg/F/BPDnF3PcUdzZrwdw0N0Peb/09B0Abh6BHyPD3X8A4NRzhm9Gv3AnMKQCnsSPoePuR939vsHPs+gXR9mJIa9J4MdQ8T6rXuR1FMG+E8AvL/h9lMUqHcB3zOzHZrZ3RD48w3Z3Pwr0LzoA20boy4fMbP/gY/6a/zlxIWa2G/36CfdghGvyHD+AIa/JWhR5HUWwp0psjEoSuNHdXwvgDwB80MxePyI/nk98DsBV6PcIOArgU8M6sZmtA/A1AB9293PDOu8S/Bj6mvgKirwyRhHsMwB2XfA7LVa51rj7kcH/JwB8A6OtvHPczHYAwOD/E6Nwwt2PDy60CsDnMaQ1MbMG+gH2JXf/+mB46GuS8mNUazI490UXeWWMItjvBXD1YGexCeDdAO4athNmNmVm65/5GcBbADwYz1pT7kK/cCcwwgKezwTXgHdgCGtiZoZ+DcMD7v7pC0xDXRPmx7DXZM2KvA5rh/E5u41vRX+n81EAfzUiH16EvhLwUwA/G6YfAL6M/sfBLvqfdN4H4BIAdwP4xeD/LSPy418APABgP/rBtmMIfrwO/Y+k+wHcP/j31mGvSeDHUNcEwKvQL+K6H/03lr++4Jr9EYCDAP4NwNjFHFffoBMiE/QNOiEyQcEuRCYo2IXIBAW7EJmgYBciExTsQmSCgl2ITFCw
C5EJ/wfUZjIsdjFZ/gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import mindspore.dataset.engine as de\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
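    "# load the CIFAR-10 training set (binary .bin files) with shuffling\n",
    "train_path = \"./datasets/cifar10/train\"\n",
    "ds = de.Cifar10Dataset(train_path, num_parallel_workers=8, shuffle=True)\n",
    "print(\"the cifar dataset size is :\", ds.get_dataset_size())\n",
    "# fetch one sample as a dict keyed by column name (image, label)\n",
    "dict1 = ds.create_dict_iterator()\n",
    "datas = dict1.get_next()\n",
    "image = datas[\"image\"]\n",
    "print(\"the tensor of image is:\", image.shape)\n",
    "# the raw image is a 32x32x3 (HWC) array, so it can be shown directly\n",
    "plt.imshow(np.array(image))\n",
    "plt.show()"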
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the CIFAR-10 training set contains 50,000 color images of 32×32 pixels."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the Data Augmentation Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import mindspore.common.dtype as mstype\n",
    "import mindspore.dataset.engine as de\n",
    "import mindspore.dataset.transforms.vision.c_transforms as C\n",
    "import mindspore.dataset.transforms.c_transforms as C2\n",
    "\n",
    "def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target=\"GPU\"):\n",
    "    # note: target is accepted as an argument but not used inside this function\n",
    "    ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)\n",
    "\n",
    "    # define map operations\n",
    "    trans = []\n",
    "    if do_train:\n",
    "        # training-only augmentation: pad 4 pixels on each side and randomly crop back to 32x32,\n",
    "        # then flip horizontally with probability 0.5\n",
    "        trans += [\n",
    "            C.RandomCrop((32, 32), (4, 4, 4, 4)),\n",
    "            C.RandomHorizontalFlip(prob=0.5)\n",
    "        ]\n",
    "\n",
    "    # resize to the 224x224 input size of ResNet-50, scale pixels to [0, 1],\n",
    "    # normalize with the per-channel CIFAR-10 mean/std, and convert HWC to CHW\n",
    "    trans += [\n",
    "        C.Resize((224, 224)),\n",
    "        C.Rescale(1.0 / 255.0, 0.0),\n",
    "        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n",
    "        C.HWC2CHW()\n",
    "    ]\n",
    "\n",
    "    # cast labels to int32\n",
    "    type_cast_op = C2.TypeCast(mstype.int32)\n",
    "\n",
    "    ds = ds.map(input_columns=\"label\", num_parallel_workers=8, operations=type_cast_op)\n",
    "    ds = ds.map(input_columns=\"image\", num_parallel_workers=8, operations=trans)\n",
    "\n",
    "    # apply batch operations (drop the last incomplete batch)\n",
    "    ds = ds.batch(batch_size, drop_remainder=True)\n",
    "    # apply dataset repeat operation\n",
    "    ds = ds.repeat(repeat_num)\n",
    "\n",
    "    return ds"
   ]
  },
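  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Per pixel, `Rescale(1.0 / 255.0, 0.0)` followed by `Normalize(mean, std)` computes (x / 255 - mean) / std for each channel. The NumPy sketch below reproduces that arithmetic and the `HWC2CHW` transpose on a random stand-in image, together with the inverse mapping that is useful for displaying augmented images. It is an illustration, not the MindSpore kernels, and the function names are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Per-channel CIFAR-10 statistics, copied from the Normalize call above\n",
    "MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)\n",
    "STD = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)\n",
    "\n",
    "def rescale_normalize_hwc2chw(img_hwc_uint8):\n",
    "    # Rescale(1/255, 0): map uint8 pixels to [0, 1]\n",
    "    x = img_hwc_uint8.astype(np.float32) / 255.0\n",
    "    # Normalize(mean, std): (x - mean) / std, broadcast over the last (channel) axis\n",
    "    x = (x - MEAN) / STD\n",
    "    # HWC2CHW: move the channel axis to the front\n",
    "    return np.transpose(x, (2, 0, 1))\n",
    "\n",
    "def denormalize_chw2hwc(img_chw):\n",
    "    # Inverse mapping: CHW -> HWC, undo normalization, clip to the displayable range\n",
    "    x = np.transpose(img_chw, (1, 2, 0)) * STD + MEAN\n",
    "    return np.clip(x, 0.0, 1.0)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "img = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)  # random stand-in for a CIFAR-10 image\n",
    "out = rescale_normalize_hwc2chw(img)\n",
    "print(out.shape)             # (3, 32, 32)\n",
    "print(out.min(), out.max())  # normalized values fall outside [0, 1]\n",
    "```\n",
    "\n",
    "The sketch leaves out `Resize((224, 224))`. Passing `denormalize_chw2hwc(image[0])` to `plt.imshow` in the cell below would display an augmented image in roughly its original colors instead of a clipped, normalized one."
   ]
  },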
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the data augmentation function defined, let's look at what the augmented data looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the cifar dataset size is: 1562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the tensor of image is: (32, 3, 224, 224)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO2df6wsZ3nfP0+uB64nuV7wXcArc09tHzlYBKkOQRCJlKZNkwBK61A1FFQlDkV1kEBKpFSKSaIWRYqUpCFRorRIjrCAiPKjIgkWok0slIj0DwjYMdjEGHyMc3zxYnevk7m3nXvJ3JO3f7zvu/Pse2b37Dn7Y/aceT7SnJmdMzvz7u6833ne533e5xXnHIZhdJfvaLsAhmG0i4mAYXQcEwHD6DgmAobRcUwEDKPjmAgYRsdZmQiIyOtF5DEReVxE7l7VdQzDWAxZRZyAiJwCvgb8MHAe+ALwVufcXy/9YoZhLMSqLIFXA487555wzv098FHgjhVdyzCMBbhmRee9EXhKvT4PvGbawSJiYYuGsXpGzrkXpTtXJQLSsG+ioovIXcBdK7q+YRj7+ZumnasSgfPAOfX6pcDT+gDn3D3APWCWgGG0yap8Al8AbhWRm0XkecBbgPtWdC3DMBZgJZaAc+6qiLwL+BPgFHCvc+4rq7iWYRiLsZIuwkMXwpoDhrEOHnDOvSrdaRGDhtFxTAQMo+OYCBhGxzERMIyOYyJgGB3HRMAwOo6JgGF0HBMBw+g4JgKG0XFMBAyj45gIGEbHMREwjI5jImAYHcdEwDA6jomAYXScI4uAiJwTkT8TkUdF5Csi8rNh/3tE5Jsi8lBY3ri84hqGsWwWySx0Ffh559yDInIGeEBE7g//+23n3G8uXjzDMFbNkUXAOTcEhmH7kog8ik81bhjGMWIpPgERuQn4XuDzYde7ROTLInKviLxwGdcwDGM1LCwCIvJdwCeAn3POXQTeB2wDt+MthfdOed9dIvJFEfniomUwDOPoLJRoVEQy4FPAnzjnfqvh/zcBn3LOveKA81iiUcNYPctNNCoiArwfeFQLgIgM1GFvAh456jUMw1g9i/QOvBb4SeBhEXko7PtF4K0icjt+2rEngZ9ZqISGYawUm3fAMLqDzTtwKLK2C2AY68FEYBpV2wUwjPVgImAYHcdEwDA6jomAYXQcEwHD6DgmAobRcUwEDKPjmAgYRscxETCMjmMiYBgdx0TAMDqOiYBhdBwTAcPoOCYChtFxTAQMo+MsklkIABF5ErgE7AFXnXOvEpHrgY8BN+GzC73ZOfe3i17LMIzlsyxL4J85525XWUvuBj7jnLsV+Ex4bRjGBrKq5sAdwAfD9geBH1/RdQzDWJBliIAD/lREHhCRu8K+l4QZiuJMRS9O32TzDhjGZrCwTwB4rXPuaRF5MXC/iHx1njc55+4B7gFLNGoYbbKwJeCcezqsnwX+CHg18EycfyCsn130OoZhrIaFREBEvjPMSIyIfCfwI/jJRu4D7gyH3Ql8cpHrGIaxOhZtDrwE+CM/GRHXAP/dOfe/ROQLwMdF5O3ALvATC17HMIwVYZOPGEZ3sMlHDMPYj4mAYXQcEwHD6DgmAobRcUwEDKPjmAgYRscxETCMjmMiYBgdx0TAMDqOiYBhdBwTAcPoOCYChtFxTAQMo+OYCBhGxzERMIyOc+SkIiLyMvzcApFbgP8EvAD4D8D/Cft/0Tn36SOX0DC6zgAow1It//RLSSoiIqeAbwKvAd4G/F/n3G8e4v2WVMQwpvED1CIQheBootCYVGQZ2YYBfgjYcc79TUg1ZhjGsijDOgPysB3XFT6B3wIsyyfwFuAj6vW7ROTLInKviLxwSdcwjG5ShSWKQA/fRBgA20B/sdMvLAIi8jzgXwH/I+x6H75otwND4L1T3meTj2wYp4FTbRfC2I+2BHph6QNb+Jq2HfYdkWU0B94APOicewYg
rgFE5PeBTzW9ySYf2Tyi/yk2Oa+0WxwjEtv92hLQS/zRHuZIjsNliMBbUU0BERnEKciAN+HnITA2nDPADfj76TL7fU8mCC1SMukLiJbAIKy1cs81/9ckC4mAiOTADwM/o3b/hojcjp+j8Mnkf8aG0mf//XQtkw+aKAp7LZXxDP6+/1pL12+N+HTPwxIFYIBvEqQ/0CEdhQuJgHOuBM4m+35ykXMa6ydaAWlzIC7XhbUWhEodt2pROBXKdg4YbEO1A99Y8TU3ilJta8fgFrA9gHK4/0cbzX/6ZXURGscY3YzMw+s82R/vsfRei02HgtUJwl44fw4w9N7mTqG//Cp9HX4l7S/oYyJgHJ0shzyDLPPbWQZkcE0OlwsoR1AWdaXX1sJl6v3L9iFc8sWgKjvqnyiTZSwIU0RgwNxqaSJg1OS+4vd6cG0P8hzyHvTyjLyXU4wKihEUI7gc1gV+6YV1KgaXkkukXZCHsRyeYyVRs8eDRgHA/wlCPSECTdGEU6wDEwGjJvOV/to+9PvQ62f0+316/T5n+30uDocUwyGjYUExhCKHYggXKn/PxYofxeAi/p5MmxV6nYrEQRz2+BNDGjI8yxKIQmAiYByW2ATo9aE/yBkMBtwwGDDYGpANBrDbY9Tvcba3y4V8SJHDKIN8F8qqdhym1gHU9+xVJh9kna3URyEVANgvArH3IO7TPNx8WhMBoyb3zYBeP6M/GHDD9hZbW1uwvQVb56Dfo7+b088zRjmMsqG/7yrvoC6qycofBSF1KlbUDsUzmBDMTdoU0E96bQnE16kITMFEwKjJIOvhzf+tgReA27Zh8DJgG7ai1xD6WUVOSV4V5GWo/Lt15U99A+mSMRkDY0IwB00+gTimQIuAfj0HJgIGV4DHgGzHm/VlNaSsSi5XJbdUJb2qgK0RjJ6AnV3Y3aEY7jAaFlwsoCihLCe7s7Mp21XD6846+w5L7P/fpf4S09DOtB93DkwEDMB73h8Byt3Q/18VFNWDlFXBLdWIQVVQ7e5SDHcoh7tcHFa+d6DwXYZVNVmhY5yBvlerhv3GIRkxOYy4yTJIva8HYCJgjHkOb75f3I0WAZTVDkU5oqhGXB4NKYcjqlGIFwgxA2UFmbrhsmQd0darcURiWHAUgEMEBU3DRMCYYA8fmz929FVQlAVF9TDVCKrCL6jtqoCs3N8U1b6paAXoUbHGESmBHbwAaB+AVl9zDBqLcp4QDFR6IbgYnzzB05clHr9r8ZU83ne5WqIAXOZQ96ZxELEvNhJ/gJzJH+MATASMqVwCvlTClx6E6x+s7604ulBX9D4+BiBW8hizcpbJiq97B4wlE82sKAwmAsYyeS4sZ6ijAPtMWqDRGshz6GVwNoN+DtXQOw6b8mReVtfo5JiAVWKOwY6TsZK+t0tM9gBEq3PsD+hBL4frcuiH7SqDq8HZGAOFYni77sWqaC9XQZeZK8dgSBj6rIg8ovZdLyL3i8jXw/qFYb+IyO+KyOMh2egrV1V4YwoZCyefnMUV6iAgfck8D1ZAH3oDv7AF/QGc3YJ+5ot1ljq8vYcPLIrNB2P9zJto9APA65N9dwOfcc7dCnwmvAafc/DWsNyFTzxqrIsoACuuUZeoQ3/Bm5R5GIF4Xc8PQBonvohCMPDNBC0E0bcQmxKnV1tso4G5mgPOuc+KyE3J7juAHwzbHwT+HPiFsP9Dzs9q8jkReUGSd9BYJXoU2YqJ7fuMevBRHH6c6RRY1Np0GR+Q1MOPMtTh7lep/Q3WLFgfi/gEXhIrtnNuKCIvDvtvBJ5Sx50P+0wEVo0eTx6tgWLmOxZCD1nPMj+sIO/55gDaEsAf2IsBSGUINKKOddHBbzk2lmCdrMIx2DQF0b6U4iJyF765YCwLnYZ6TSIAk5bAtT0vBBOJMFU4az+IQFFBr6if/D11mInAellEBJ6JZr6IDIBnw/7z+JyQkZcCT6dvtnkHVkBT9pkVsqcuVxQ+uOhizycb6ekgghjUEnoHyIL/
oPAJTnV8i3YQPoOxDhaZgeg+4M6wfSfwSbX/p0IvwfcDhfkD1kTTaLIVE7XmMn4cwWjkl2qEbwAO8Ta/Kk+W1U7EPv6JcY56Qp3b8FNcn1l98Q3mtARE5CN4J2BfRM4D/xn4NeDjIvJ2/JCGnwiHfxp4I/A4/qd/25LLbEwjHVa6BhG4qi5ZlN4auBD8A4MYPBAtgdIHDcVmQ/RdlhX0yv05Bwrgr1f/ETrPvL0Db53yrx9qONYB71ykUMYCpEKwhsvFS17Ei8BIZSruh+whVVUPN47NgZjJuCpDIFFVb18MWYouYM2CVWMRgyeNfdloV3+5dM6LbKQqeIgrnhhqHDIYUXmLYSwQFVwN5S52/LlGmAismm6LgE7OcFLS27RoCRSEnoIKP9Q4CkAG12TjzGT1vAYZk999VefNHAX/whB4CSYEq6TbIjBgMgtLmpVlzvRMG8caRU1rThQB8DuqMDtWLwQREZyBmXqdljNaDLeMYDjyASc7mAiskm6LQNPEjSdhjKtOQbViYlxQHEgUhwlXwMUSRjt+JGGv7wcUXa2g6nlfQQ77v+8kLdFJMdA2mW6LQNMdFqNVjrsYrDBIKHIaPwYgTXOfXvpiCWd3oVIH5lWDCOjtaCkcV2vsGNFtEZjGSUiEtwYRSCe7SXMF6O0SgtfQNweuK9XX3JQaK699CldX/1E6jYlAE+NRMZg9OoNoBdyA78rTDsKC0GVIkh6/8DEBvSokJdHJCCYSE3ixWIeYdR0TgVnoLJnGBGeo/QFNlsAFvMtljyQJSQn9kKG4xPcUTGQoTSyBk9Rxs6mYCMwieq+PkMv9pKObAn3qB3a0BIbUw4GfI8k7WMKgUhaCFoA0XZGxckwEppEmz2+6Iad5toeceDM2NgXiOjT3Ad8MSPMFPqOOvSFGCOKf9pm2BMKSxWwja+AU3c5fYCIwjSHNbdWmnO7zCMYJ4nommwJn+zAY1SnHYvv/G+o9340fGHQbsL0NW9uwtcX+JIX4E5Qx1BgzvlaNicAstEcrvVmbhCC+jsee0Lt3XPnxMQD9QRhFOJr0qWZ4v0CYztSPDrwNtm+Dc9vUCUdSqslW2An9GjcGE4F5SO/Epil2OiICOjZgAAwG0N8KTrwcGHnnX8wNcBYvAC/DV/7t27wQ9LbxU54XhR81VE3W/MuVWQLrwkRgXvTdmFoC6esTHOgSrYAb8MlEz4YMQoPg+ct6kBeQj4IDEDiX+8q/tZ0IwNaWHyBQFCHnWC0GVaX8Bu193E5gIjAvTQIwo497XU6tdROtgD6+GdCPacRyGJSQ9+G6wncDjgoYFL7yn9uGwTbk25kXgO0tb0LE6KFRETKRFLUAVGsbB9VpTAQOg+7wjutpInACnYPRITggpBDvQxbzCIaooF4IBhqUPsnIhRJu2IL+NrCV+6f/VhSALZ90kBAQUPrnflmVVFU9O5GJwGo5UARE5F7gx4BnnXOvCPv+C/Avgb/HD/J6m3Pu70Ja8keBx8LbP+ece8cKyt0esyr9CbcEogCcxTcFxlbAgHpKoWC/55X3DQwqvEgM+r7yD7Yg3wpvOgeEbKMl/k2UVFXpT1WdyBbVxjGPJfAB4PeAD6l99wPvds5dFZFfB96Nn3MAYMc5d/tSS7kpFHjJi0PlmnwC0SXeNELxGHMKP1NQTsgFoP+ZNty1b0R3F0BIGJAkQizqpSon3QProMsxAjCHCDRNPOKc+1P18nPAv1lusTaYODd8pEkMRumbjj8xdmf8cXVIdZMA6O2oHGPU0KLxTKUVFD5Zaak6DIzVswyfwL8HPqZe3ywif4UPHPtl59xfNL3pxMw70BH3dZzGIKfOEATs7zWhaTubtASqCp895HLITlqNI43K0g891sFCxmpZSARE5JfwIz0/HHYNgS3n3AUR+T7gj0Xke5xzF9P32rwDx4fYFBhbA03NgWkRlJl6U+z7qyrI9jcFisJrQllmVFVllsCaOPK8AyJyJ95h+O9ChmGcc992zl0I
2w/gDefvXkZBjfZIJw7NdBMopWEcgB8qqIh9gBS1CISn/+XYJCibDQxj+RzJEhCR1+Mdgf/UOVeq/S8CnnPO7YnILfiZiZ9YSkmN1tCzms3VHNjXa5I0B8pgCVR5GFNcURYVZVnVTsGyMy2t1pmni7Bp4pF3A88H7hcRqLsCXwf8iohcxTtd3+Gce25FZTfWRJxGfNw7gM8ePCEETQIwtgJSS4C6lyBYAmMroMwoy4zL1hZYGxIs+XYLYT6BjeUfA68EXo2KEgzhwv0B9RTkTTETOvuoSklchkCiooS/ethnFb4w9Os4c9kFJnMSzMPpsK4O+b4O8YBz7lXpTosYNKZyPfXEwgN8xuBeD67NkyZByf704TpGwEcDUxShwhf1vAKPfBWGZT1l4bfwSUiOgr5kbEqYGByMiYAxlRuoIwQHA7iu5+cLyHuq2z/WthhAFScf0H6CEQyH8K2w3h2F7RE8SP3UTxORHIbTTCaC0pxoIVjCaFUTAaOR09QDhbYyONf343yyaAXo8RFp6rWku3A0hG8NYWcXdnfhiZHvNtpleROOzuqwgBMsBNssnMnKRMBoJA4JiEOGB3qGkXSA1AFu/NHIV/6dHXiigK/iB5csa1ahU6juS2rDRBfvRBLGYI05ohCYCBiNnCUIQbACxs4/krWuYVW9qsJGRXD67cJTQQAe4ejt/iZ0Z4Qe6Fmq/8MJswZ61A6bSKp+c2IiYEwltfKB6TkUtG9ABQaWlXcIFkU9H8EqcrDqsqZDlBbxNWwsBd6TGrtTRhx5yKWJgDGTCsYzAQHTuwLH0wz5N8XRgBdLLwAXw/CAy6zuiVzhY9i1EJxIAYjE3O5xDvcjYiJgzGaWAOjQYJgIGqqqEA8Qw4Cp5yRYNtolUSXLiSZOPDtc7DRHHjtgrIkWA+djRdo3F6AWAD2wQPUWVCECsBj5mIBVNgV0eXW5T5QPYBoLCgCYCGw+LWUoGj9Fs8ntxgFCTSIQLIGy8E2AODfhqs3zKAAnuhmgWYK5YyKwyWxivkItBNEK6PX2lbMKAqAdgqtOFab9Acb8mAhsMmmbuwUqmDE4CB9CGEVAz7cQHINaANbRFEi3jYMxEdhkWrQEoglflCHOX9fmdIxAsmQ9rwu93qTLYFVatkfdY1bQEV/AEjER2FT0k7UFIYh1/mIFoyAEpcomDKFMWTXpFwi1vteDfljO1rtXRhSCzvgCloiJwKain65RCNZINN9H+Hb9hZD6a6oIJD0Fec8PONLWwHX4EN9VYRbA0ThQBETkXhF5VkQeUfveIyLfFJGHwvJG9b93i8jjIvKYiPzoqgp+omkwsddtCVyitgaiJVDogKCITh6SiEGv58cd9PqThoKxWcxjCXwAeH3D/t92zt0elk8DiMjLgbcA3xPe899EZJXifzJJK/9BQ+RWxEXqJkERswDvswTCtrYEogj0lTWQ1UJgbBYHioBz7rPMP97jDuCjIeHoN4DH8UlpjHlpagZMJPxfH2NLgGAJ6CZBZCwC+y0BoiUQmwbMKQI9Nqtb9ISziE/gXSLy5dBceGHYdyPwlDrmfNi3DxG5S0S+KCJfXKAMJ4vY997UFGihSRD9AgVwIVgDVWoJ6HkZtQCEzKR5r85IlGfeLzCVDD8q7jb8OPkBZjqsgaOKwPvwP9Pt+MDF94b90nBsY/5A59w9zrlXNeU86ySpADTF6K+5Qb1H3SQo8DlBL+hBANFMYDckDaQezKJSihUhs3hZHRAw1JuymBCslCMNIHLOjfNBiMjvA58KL8/jZ5mMvBR4+sil6woNXWwTT/5M/a9irdOc6fpelMFJWPqpxycigVTFj2JQhRRio+hPoGEcQiRHzXnO/qHKqx540GGOZAmIyEC9fBM+TwTAfcBbROT5InIzft6Bv1ysiCeYjPqmT598qVMwdbytifjgH+GtgCKMDJx0GDQsBQxj12K0BJgRzTfNCkgXY+kcdd6BHxSR2/Gm/pPAzwA4574iIh/Hp467CrzT
OWfdtynR8TWtKzAVAZ1Bs6/2rSEg/wqqSRAChy6UMCghH4XypFZAiDAcZxUO4tHUwwjsF8ImS0BbRSdwwtc2mWdW4rc27H7/jON/FfjVRQrVOvGmW2YFa2rvp23/WUtsF6eWwYJJJuehwDv0dBhxUUAeLYEeE1ZBOfJ5BUehOXBR+QP2WQITA5GoxWBWYcIMZsZy6HZSEf3EaaqIMX3TojdczAU3bwVPt/X/YxafNJ/cCitFibcGdJPgQgGDKU2CUZhbYBgtAWqjZZ+uphaANvvThCYRnT/MWJhui8BrmP30HVI/bXePeI0Bvh8ligDJNWDyBm96DbVTsKIxkw+wMiGIcfkTzsFQwfsxu02o6aMCvhUmGRkVfk6B+N7LNNTbVAC0YzB+nnQ7ikBbzYLoxFwgr98m0W0RiJWqqVJqj/QyK1dThT+o/z/e+JnajhTU5vgKrYFLhElCKl/x+zn0RtDL8RORDP3EIk8MYTdMMPJU5ecXGOJnFmpMMa5FJPo5ouBBs7Wk948wi2BBTATSbX2TTbVhD0FTBZ/m9JpVvmmmcRSAHisftD/CjwgcFnDdKAT/ZLCd+Yr/VBCAnaGfinoXHzk2xIvI1JNG/4j+vtPhyroppL+3Hgsn2uw63RaBeEOlT5JYyeLTddlPmlkikFZ8bfI3CUHJpACsUASeo54yrFf48QB55i2Bp2LlH/rKv4MXgG/Mc+LoXIyfQWUt3mf9aAerFsB+eN+Q9VkGB1lwx4Rui8BBuahW0QWXVmQtAlWyrY+f1WSJFWgN1sAQP9NPL3QRZpmvjzsjLwBxerEn8JFjc6EFTFteqXNUb2sR6Ktz9EIBVt17cEIEAEwEJtfp/lVYAZG0dyAKQNWwPa0LMVYEXQGazOklEn0Du0BWQjaErFpAACIx5iAKQZMAxLUWPN1TUqpjdliPEDRZkscME4FUCPS+VVkB07oCU0ugyWpIK0RqBayhWTDURSjhmhHsVF4AjjzHoG4KxMoMk21/LXrpBAMVnK7gSqyUq/wOmppsxxgTgXSmimXOXNHk9GsSgNQn0NQkSAUgLk1WgK4oK2CPWgjAXytaAAtNMqpFQH9H8bPqQCL1e50B8sr/a6eCvfidDFlNF960rt150NGfG0K3RaDJo6wr5aLNAW2mRudXvAH0dnpTNDzlJpoJ08qsxSJ621fEFepIQh0MtBDpZ40WQdP3kLwtrveaLKxlox2Q894f2ppJP2PLgtBtEdABKPpJm7Y/F/mhomk67ek+q/Iflnm6HZdINEBy6vEFC6ObA6kQ6GaCpql3J/6GqxLCwyheFIC46OZMzv7mzZrptghE9A+ih7RO8xnA/An19THRnE1/7FlPu2kWQSR9+ushx4vGOBzAFbz/LT4YFx4pVqolCmcUYr0dmdbU0t9H26RDpGNPRpNTM1qMaxYCE4GUNIHHtCCdGDOvY+ebKlxqyscbOVb8Wcz7dEhv/jVaA3v4j77UoaJpMyB9DZNWVJZ8Tel30aa5rXtxdDyDFjvNiptxTZgIpGjTrWlMf9zWc8Pnat+0H1A3C3S34KyFZD3rZm56Cq6pL3tpuf51G3mWT0AzyxrYhC48baFpS0DfC6nVYyLQMrES9dlvYudwOoc8h+fi4KIYnKLb/NMqbJMJOOsGTYVgVpnjOm0aHEcOEsdMrZtIv4e2BvnoMmhLIFZ67SvSMRFrLvM8SUXuBX4MeNY594qw72PAy8IhLwD+zjl3u4jcBDyK7y4G+Jxz7h3LLvTKiO2xWEl133RYX+nBlTjgRbvFm8zUJmb1BKSkN8i0mz6+t6mNedyIYwm0VRA/S3yKJk/6K5n/TUr93cffbUBtha2btOs3jYXQD4L0nlgj81gCHwB+D/hQ3OGc+7dxW0Tey+RXvOOcu31ZBVw7BfBVahFoSnqpk2ikIgDzC8E0p6M+R9q+n9bW1+c4ziKgnWNakHUMRBozEdZXtONQW3RQxwys8zvRTbImEWh6AGxi74Bz7rPhCb8P
ERHgzcA/X26xWqbEC0FTxhtt0qWqDpOVtekHbbIC0v2RWUFF6Tl1hdHr40h0sqYOtKYcjGl6tvi5owBox5zuqVk0BkTndNDE12myWC0Cqe8jPc8GWgKz+CfAM865r6t9N4vIX+G7jn/ZOfcXC16jPeLNEj22UQiaunjS9upBzHL4pTEF00jNzVQIjqsIQLMA6MCEPNnWIcVQO3jjOjYJogUXLY55xaDBPzTTb5E6Z+PvQsOxLbOoCLwV+Ih6PQS2nHMXROT7gD8Wke9xzl1M3ygidwF3LXj99VBSj0yL7VKY/mPO8krPagY0ebr1+ZLlVAV70/wLx7Ep0IRO7BKFQFd43UTTzsL4PfUazqXP02QV6O3EJzTh5JvV36+bKqj/b8jTX3NkERCRa4B/DXxf3Oec+zbw7bD9gIjsAN8N7JtlyDl3D3BPOFfjBCUbR1P7P+UoloDe1xSXMKUpkGUhTLap7bwhInCKJcUQ6MqVVsAoAKllkDYZ4pM/JiLJ1Rqaf7cmf5AWgSJZp118aa/RBrKIJfAvgK8658ajRkXkRcBzzrk9EbkFP+/AEwuWcbOoqLsED1pmnSNdH2QF6O2sfnml6VwbYgmcYjJKdiliUDHpM2iqXHqEpfbnxB6dVCxmZSVK/UH6tbZSdMp1Lczp75Jafel6Ex2DTfMOOOfej599+CPJ4a8DfkVEruJ/83c45+adzPR4cVRlT5/a6fiFpu3YxtTnyMPl03O1EGzSRBSAWN+0xbw0MWhap4FdoaKf6iVNp7QPv8lHUzGZADUVgVmBSk0O2ibzP72HWhDvo847gHPupxv2fQL4xOLFOuFoT3HTE6FJBPQSHq9VvNlSJ1rLIqAF4LoMqmp/139clhZtqNFCEAqS90Kew6aAKu0TSAuYZkJORWCa9ae7K9OKPU0AovNyzVjEYBvogJe0/dhkETQ5HvMwbr5paZmxAOTQ63kRiMu1FVyt6o906AxEmvQ7ieukkp/q+SjPMvpQUitAd9ulzYx5REBfP24Xyb74usmpDPt/4zViItAG8YbT7cBpAqAdjU3m54aJwBl83YoC0O9PikBcCOthecTmwbQKo7+3UNHzPNT3DIo+7MX2Sepg1IFKOkoxVny9Pc3vk+6LlXtad3DBuxUAAAwhSURBVO4G9OiYCLRF+kSYJgDTbrQoIk0+gZY4zaQA9PqTlgDsF4H+7oLZiDQN1sCpHPLMJ0cdW/x5CP2G/SKQCusUATgDXMrq1xPX1+gK3iQEGxDTYSLQJtp5NE8vQ5O1sCFWwCnqp22vB3m/FgJd6dP12QzKasa8BIclsQKyXv0y/XdJEAMtCHEdfSs6D0BoAsTPWhFClfvq5JFYseOFpgnABjTfTATaJu02OqirMW0mpIEqLREF4Lq+d8KNLYGQnCVWfBIh6PXg4uiIvQYHNQmCFaBDBjK8VXCVya+vIDQTBkwO7dUiEIg+j8gVJv8/Llu8gL54+v8NwERgk2hyAGYN+/R26sxqgVP4ypb1w5M3rhvCayvV9q0AsmYreiZp5Rwk27GWZnU9HLG/4jf2UDT5WJJ2farb4/frpkEkfeqngUkt/m4RE4FNpulpsYFNgTz3T/9rmwQgjbEPFaGK5Y1m+7wOwjgwaECzAESzPTx9myr91K7JdFBY+iYlApqJqEgdPzBv703LXbomAseBeOPopkCMfW9TBHI4Hbrfsryu/Hmv3r6mB1d1kyUKQFxn4b3lATEDaWKOVATiWsX4n6I29aNBtU9o9BN5nrH+ydsI5544bxSqtKmWBnKlfoOWMBE4LujgIi0EbTQFkv73PPdWQLQIsnSEH0xWNiUEWToMuIk0Yk9X/AGc6kM+8EuVqYUQJZgllXRa95wWgqZjko8C+5v5ExZBFIJ0mrSmMqTNvjViInCcSHMUrtsfkDFRufNQ8a9TYjC2CHQF10EyMPaWV/jKOhaDJhFI++nDcnoAWaj4eR+uDesSP4a9xPc67EFzhZ4lBPoJPsMS0H4B
ku2xs1B3NerApGlLC5gIHDeiEOin5zpMyWQo7Wn15NcCMBaC2NZXlTt9wFL6/18T1qdiFGREVfpTfe9vyNVT/9pBvR0DgkaoJnkaR3GQAMS2QyoKSQVt8l2kzv8rcWdMQKMHUswqQwuYCBw30ifWuqwA7XAL7f685wODxkIQ2/dqrQc+ZYTegVBjKmXVZJW3IPZihQiV5lToZpx44qdL0g1YAZfD9pX4faVP4GmiMEssFKkVkK5LghBEK0YLQTxBg8O0DSEwETiORKfgOi0BmLhpnyuT+7cMVkBY52UtCpS+8lehgo2jB/FP66tREHQsv6rVsX1f4W9YHX5cKq99PCfhuAzqJy/sf6qnT99UTJuEIKtPxeSlJ4Rgn5MzCl4csahHVKUnWjMmAscVnbZ61aQ3aLhxL2W+YpdVEIAKrs2CEGTeEsgbzN9KVcQygypX+3K11kE2QQyyKtSdcHxe7X+IRyaiAuN5dH+9juSrku1pvoLwJI+V/Aq+F2LaVzXemTcs8fqpGbFmTASOK9qTvWqijZ0+RQtvbl/JoahCM4DgG4jbWX2TZeF943s+WAFjIdDX1JUlm7zsNdGSqKBUFVM72LVhQRbCe8eFoBYC/QZdAH1BLQQ6VDCwN2V7jLYCku7SxmjCNTNPUpFz+HTjNwD/ANzjnPsdEbke+BhwE/Ak8Gbn3N+GDMS/A7wR/1F/2jn34GqK33HWHWSiRUA/TSs/RHcvC0/IEp4LN/XpTIkDajuIQ6l9A7Gi6F6IaAGkgTqhKZBHIcj3i0A0KAjrIlc+h/gZ0jc1fd5UCHIOj1alJgHYZBHAR1v+vHPuQRE5AzwgIvcDPw18xjn3ayJyN3A38AvAG/BpxW4FXgO8L6yN40raGR5J4xbYv32F8BQOy6msduRlTVZANMvzyXWVmOpZWI+FgP0P9niKa6kdhWUGZa66DnN1zfTp39QcOIoIxMJEh2f0eWyAAMB8mYXijHs45y6JyKPAjcAd+LRjAB8E/hwvAncAH3LOOeBzIvICERmE8xjHjfSG1xUlU+tI07aq0Hu5HzFYBjGIlbvKkz790GU49gckxSjxlZ9gCcSkIRpdv3TdK/LgK9DOvmkVMRUC7bQ8DNOEYAPE4FA+gTAJyfcCnwdeEiu2c24oIi8Oh90IPKXedj7sMxE4Lkx7sjcJwLT3a8aN8/r1HsGdEY4dZ/zRDkL9ONfNgcQKiNuxqaGLnb6OLYAi7LiSHqSZZhFoC2ge4rFROHTa8w2wBuYWARH5Lnz+wJ9zzl30Tf/mQxv27UspfqzmHegiTSZy+v+UtDIdUKn2ZgXO5EEY8joT0JVw/CnqCh+t8wz8WIV8uhjEaMKigHIEV3S24OhkjZmMpzldm5RlWkVOg43SZsW+6Kl2mEsERCTDC8CHnXN/GHY/E818ERkAz4b954Fz6u0vBZ5Oz3ks5x3oAk3t/yav+UHnOOjplvbrNYjARF+6Eow9aidkHCQ0vmQZRCCDa7J6O8M3G8oRlAVciinC9XySZbJ9UM/LNEGY9vl0z8AB4xPWyTy9AwK8H3jUOfdb6l/3AXcCvxbWn1T73yUiH8U7BAvzBxwTpjkAD0va1i6T/drKmBZLn/apN1WUIAiXCNZBaBLo7skYiJRlcHlUiwBNIlAl2zqQZ9rClO1ZApdaAy0zjyXwWuAngYdF5KGw7xfxlf/jIvJ2/HQcPxH+92l89+Dj+I/4tqWW2FgN+uZdFrrCx2vofbOeiLFPfp6nZbQO4iFBBCbEgFoA9vS08nG76Wk9zWk4SwDSzzlLBDbACoD5egf+N83tfIAfajjeAe9csFxGWyxbCKB+wubqdaw8qSWQvk9bA72GY5Ljx2KQJf7FLPgBYqXXyzRrBKa3/2nY1mvdlErP2eQTaBGLGOw66c29SvO0TNbzHB+PTUYxTp0fECbE4Ap+4tac4AeIy1Btz2JEnXcwrnW5ZjlDU59KXOspyzZADMQ/uNvFHIPGQsTutyZLQb8+qMLP
S85kYtFsyvY0h2rsgVg/DzjnXpXuNEvAOP7E5sa6wqhLvBfshPAdbRfAMIx2MREwjI5jImAYHcdEwDA6jomAYXQcEwHD6DgmAobRcUwEDKPjmAgYRscxETCMjmMiYBgdx0TAMDqOiYBhdBwTAcPoOCYChtFxNiWfwAj4fywv7UMb9Dne5Yfj/xmOe/lhtZ/hHzXt3IjMQgAi8sWmrCfHheNefjj+n+G4lx/a+QzWHDCMjmMiYBgdZ5NE4J62C7Agx738cPw/w3EvP7TwGTbGJ2AYRjtskiVgGEYLtC4CIvJ6EXlMRB4XkbvbLs+8iMiTIvKwiDwkIl8M+64XkftF5Oth/cK2y6kRkXtF5FkReUTtayyzeH43/C5fFpFXtlfycVmbyv8eEflm+B0eEpE3qv+9O5T/MRH50XZKXSMi50Tkz0TkURH5ioj8bNjf7m/gnGttwc8juQPcAjwP+BLw8jbLdIiyPwn0k32/Adwdtu8Gfr3tciblex3wSuCRg8qMn0/yf+KnoPt+4PMbWv73AP+x4diXh/vp+cDN4T471XL5B8Arw/YZ4GuhnK3+Bm1bAq8GHnfOPeGc+3vgo8AdLZdpEe4APhi2Pwj8eItl2Ydz7rPAc8nuaWW+A/iQ83wOeEGYgr41ppR/GncAH3XOfds59w38BLmvXlnh5sA5N3TOPRi2LwGPAjfS8m/QtgjcCDylXp8P+44DDvhTEXlARO4K+17iwjTsYf3i1ko3P9PKfJx+m3cFc/le1QTb6PKLyE3A9wKfp+XfoG0RaJrt+Lh0V7zWOfdK4A3AO0XkdW0XaMkcl9/mfcA2cDt+mtH3hv0bW34R+S7gE8DPOecuzjq0Yd/SP0PbInAeOKdevxR4uqWyHArn3NNh/SzwR3hT85loroX1s+2VcG6mlflY/DbOuWecc3vOuX8Afp/a5N/I8otIhheADzvn/jDsbvU3aFsEvgDcKiI3i8jzgLcA97VcpgMRke8UkTNxG/gR4BF82e8Mh90JfLKdEh6KaWW+D/ip4KH+fqCIJusmkbSR34T/HcCX/y0i8nwRuRm4FfjLdZdPIyICvB941Dn3W+pf7f4GbXpLlQf0a3jv7S+1XZ45y3wL3vP8JeArsdzAWeAzwNfD+vq2y5qU+yN4k7nCP2XePq3MeFP0v4bf5WHgVRta/j8I5ftyqDQDdfwvhfI/BrxhA8r/A3hz/svAQ2F5Y9u/gUUMGkbHabs5YBhGy5gIGEbHMREwjI5jImAYHcdEwDA6jomAYXQcEwHD6DgmAobRcf4/ZU9FRHDzNy0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# build the augmented training set and report how many batches it contains\n",
    "ds = create_dataset(train_path, do_train=True, repeat_num=1, batch_size=32, target=\"GPU\")\n",
    "print(\"the cifar dataset size is:\", ds.get_dataset_size())\n",
    "# fetch one batch and display its first image (CHW -> HWC for matplotlib)\n",
    "dict1 = ds.create_dict_iterator()\n",
    "datas = dict1.get_next()\n",
    "image = datas[\"image\"]\n",
    "single_pic = np.transpose(image[0], (1, 2, 0))\n",
    "print(\"the tensor of image is:\", image.shape)\n",
    "plt.imshow(np.array(single_pic))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After data augmentation, the CIFAR-10 training set consists of 1562 batches, each a tensor of shape (32, 3, 224, 224)."
   ]
  },
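  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on these numbers, the sketch below assumes the standard 50,000-image CIFAR-10 training split. It also assumes that `create_dataset` (defined earlier) drops the last incomplete batch. The sketch then estimates how much memory a single input batch occupies in float32 versus float16, which is one place where mixed precision saves space. It is plain Python arithmetic, not a MindSpore API.\n",
    "\n",
    "```python\n",
    "TRAIN_IMAGES = 50000  # size of the standard CIFAR-10 training split\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "# Assumption: the last incomplete batch is dropped, so integer division applies\n",
    "num_batches = TRAIN_IMAGES // BATCH_SIZE\n",
    "\n",
    "def batch_bytes(shape, bytes_per_element):\n",
    "    # Memory taken by one dense tensor of the given shape and element size\n",
    "    count = 1\n",
    "    for dim in shape:\n",
    "        count *= dim\n",
    "    return count * bytes_per_element\n",
    "\n",
    "shape = (BATCH_SIZE, 3, 224, 224)\n",
    "fp32_mib = batch_bytes(shape, 4) / 2**20  # float32: 4 bytes per element\n",
    "fp16_mib = batch_bytes(shape, 2) / 2**20  # float16: 2 bytes per element\n",
    "print(num_batches, fp32_mib, fp16_mib)  # 1562 18.375 9.1875\n",
    "```\n",
    "\n",
    "Only the input batch is counted here. Intermediate activations often take far more memory, and any activation stored in float16 shrinks by the same factor of two."
   ]
  },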
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Dynamic Learning Rate Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a dynamic learning-rate schedule for training the ResNet-50 network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
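    "def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):\n",
    "    # Return a float32 array with one learning rate per training step.\n",
    "    # lr_decay_mode: 'steps' (piecewise constant), 'poly' (linear warmup + quadratic decay),\n",
    "    # anything else (linear warmup + linear decay to lr_end).\n",
    "    lr_each_step = []\n",
    "    total_steps = steps_per_epoch * total_epochs\n",
    "    warmup_steps = steps_per_epoch * warmup_epochs\n",
    "    if lr_decay_mode == 'steps':\n",
    "        # divide lr_max by 10 at 30%, 60% and 80% of training\n",
    "        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]\n",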
    "        for i in range(total_steps):\n",
    "            if i < decay_epoch_index[0]:\n",
    "                lr = lr_max\n",
    "            elif i < decay_epoch_index[1]:\n",
    "                lr = lr_max * 0.1\n",
    "            elif i < decay_epoch_index[2]:\n",
    "                lr = lr_max * 0.01\n",
    "            else:\n",
    "                lr = lr_max * 0.001\n",
    "            lr_each_step.append(lr)\n",
    "\n",
    "    elif lr_decay_mode == 'poly':\n",
    "        # linear warmup from lr_init to lr_max, then lr_max * (1 - progress)^2\n",
    "        if warmup_steps != 0:\n",
    "            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)\n",
    "        else:\n",
    "            inc_each_step = 0\n",
    "        for i in range(total_steps):\n",
    "            if i < warmup_steps:\n",
    "                lr = float(lr_init) + inc_each_step * float(i)\n",
    "            else:\n",
    "                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))\n",
    "                lr = float(lr_max) * base * base\n",
    "                if lr < 0.0:\n",
    "                    lr = 0.0\n",
    "            lr_each_step.append(lr)\n",
    "    else:\n",
    "        # default: linear warmup from lr_init to lr_max, then linear decay to lr_end\n",
    "        for i in range(total_steps):\n",
    "            if i < warmup_steps:\n",
    "                lr = lr_init + (lr_max - lr_init) * i / warmup_steps\n",
    "            else:\n",
    "                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)\n",
    "            lr_each_step.append(lr)\n",
    "\n",
    "    lr_each_step = np.array(lr_each_step).astype(np.float32)\n",
    "\n",
    "    return lr_each_step\n",
    "\n",
    "def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):\n",
    "    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)\n",
    "    lr = float(init_lr) + lr_inc * current_step\n",
    "    return lr\n",
    "\n",
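    "def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):\n",
    "    # Linear warmup from 0 to lr, then a linear decay multiplied by a cosine term\n",
    "    # (not used by the training cell below).\n",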
    "    base_lr = lr\n",
    "    warmup_init_lr = 0\n",
    "    total_steps = int(max_epoch * steps_per_epoch)\n",
    "    warmup_steps = int(warmup_epochs * steps_per_epoch)\n",
    "    decay_steps = total_steps - warmup_steps\n",
    "\n",
    "    lr_each_step = []\n",
    "    for i in range(total_steps):\n",
    "        if i < warmup_steps:\n",
    "            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)\n",
    "        else:\n",
    "            linear_decay = (total_steps - i) / decay_steps\n",
    "            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))\n",
    "            decayed = linear_decay * cosine_decay + 0.00001\n",
    "            lr = base_lr * decayed\n",
    "        lr_each_step.append(lr)\n",
    "\n",
    "    lr_each_step = np.array(lr_each_step).astype(np.float32)\n",
    "    learning_rate = lr_each_step[global_step:]\n",
    "    return learning_rate\n"
   ]
  },
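  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training cell below calls `get_lr` with `lr_decay_mode='poly'`. To show what that schedule looks like, the sketch below restates the `'poly'` branch as a hypothetical single-step helper, `poly_lr`. It evaluates the helper with the training hyperparameters (`lr_init=0.01`, `lr_max=0.1`, 5 warmup epochs out of 10) and the 1562 steps per epoch reported above. It is meant for inspection only; training still uses `get_lr`.\n",
    "\n",
    "```python\n",
    "def poly_lr(step, lr_init, lr_max, warmup_steps, total_steps):\n",
    "    # Linear warmup from lr_init to lr_max\n",
    "    if step < warmup_steps:\n",
    "        return lr_init + (lr_max - lr_init) * step / warmup_steps\n",
    "    # Quadratic (power-2) decay from lr_max towards 0\n",
    "    base = 1.0 - (step - warmup_steps) / (total_steps - warmup_steps)\n",
    "    return max(lr_max * base * base, 0.0)\n",
    "\n",
    "steps_per_epoch = 1562\n",
    "warmup_steps = 5 * steps_per_epoch   # 7810\n",
    "total_steps = 10 * steps_per_epoch   # 15620\n",
    "\n",
    "for step in (0, 3905, 7810, 11715, total_steps - 1):\n",
    "    print(step, round(poly_lr(step, 0.01, 0.1, warmup_steps, total_steps), 6))\n",
    "```\n",
    "\n",
    "The rate rises linearly from 0.01 to 0.1 over the first 7810 steps, then decays quadratically. Halfway through the decay phase (step 11715) it is 0.1 × 0.5² = 0.025, and by the final step it is almost exactly zero."
   ]
  },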
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.nn.loss.loss import _Loss\n",
    "from mindspore.ops import operations as P\n",
    "from mindspore.ops import functional as F\n",
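    "from mindspore import Tensor\n",
    "import mindspore.common.dtype as mstype\n",
    "import mindspore.nn as nn\n",
    "\n",
    "# Softmax cross-entropy with optional label smoothing. The training cell below does not use it;\n",
    "# it uses nn.SoftmaxCrossEntropyWithLogits(sparse=True) directly.\n",
    "class CrossEntropy(_Loss):\n",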
    "    def __init__(self, smooth_factor=0., num_classes=1001):\n",
    "        super(CrossEntropy, self).__init__()\n",
    "        self.onehot = P.OneHot()\n",
    "        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)\n",
    "        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)\n",
    "        self.ce = nn.SoftmaxCrossEntropyWithLogits()\n",
    "        self.mean = P.ReduceMean(False)\n",
    "\n",
    "    def construct(self, logit, label):\n",
    "        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)\n",
    "        loss = self.ce(logit, one_hot_label)\n",
    "        loss = self.mean(loss, 0)\n",
    "        return loss"
   ]
  },
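  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `CrossEntropy` cell above converts integer labels into smoothed one-hot targets. The true class gets `on_value = 1 - smooth_factor` and every other class gets `off_value = smooth_factor / (num_classes - 1)`. The cell then averages softmax cross-entropy over the batch. The NumPy sketch below reproduces that arithmetic outside MindSpore so the effect of `smooth_factor` is easy to inspect. `smoothed_cross_entropy` is an illustrative helper, not a MindSpore API.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def smoothed_cross_entropy(logits, labels, smooth_factor=0.0):\n",
    "    # Label-smoothed targets: on_value for the true class, off_value elsewhere\n",
    "    n, num_classes = logits.shape\n",
    "    on_value = 1.0 - smooth_factor\n",
    "    off_value = smooth_factor / (num_classes - 1)\n",
    "    targets = np.full((n, num_classes), off_value)\n",
    "    targets[np.arange(n), labels] = on_value\n",
    "    # Numerically stable log-softmax over the class axis\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    # Per-sample cross-entropy, then the mean over the batch (axis 0)\n",
    "    return float((-(targets * log_probs).sum(axis=1)).mean())\n",
    "\n",
    "logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])\n",
    "labels = np.array([0, 2])\n",
    "print(smoothed_cross_entropy(logits, labels))\n",
    "print(smoothed_cross_entropy(logits, labels, smooth_factor=0.1))\n",
    "```\n",
    "\n",
    "With `smooth_factor=0` this reduces to ordinary cross-entropy. For confidently correct predictions like these, a positive `smooth_factor` raises the loss slightly, which discourages the network from pushing its logits to extremes."
   ]
  },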
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Deep Neural Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is the source code of the ResNet-50 network model from MindSpore used in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.common.tensor import Tensor\n",
    "import mindspore.common.initializer as weight_init\n",
    "\n",
    "def _weight_variable(shape, factor=0.01):\n",
    "    init_value = np.random.randn(*shape).astype(np.float32) * factor\n",
    "    return Tensor(init_value)\n",
    "\n",
    "\n",
    "def _conv3x3(in_channel, out_channel, stride=1):\n",
    "    weight_shape = (out_channel, in_channel, 3, 3)\n",
    "    weight = _weight_variable(weight_shape)\n",
    "    return nn.Conv2d(in_channel, out_channel,\n",
    "                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)\n",
    "\n",
    "\n",
    "def _conv1x1(in_channel, out_channel, stride=1):\n",
    "    weight_shape = (out_channel, in_channel, 1, 1)\n",
    "    weight = _weight_variable(weight_shape)\n",
    "    return nn.Conv2d(in_channel, out_channel,\n",
    "                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)\n",
    "\n",
    "\n",
    "def _conv7x7(in_channel, out_channel, stride=1):\n",
    "    weight_shape = (out_channel, in_channel, 7, 7)\n",
    "    weight = _weight_variable(weight_shape)\n",
    "    return nn.Conv2d(in_channel, out_channel,\n",
    "                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)\n",
    "\n",
    "\n",
    "def _bn(channel):\n",
    "    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,\n",
    "                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)\n",
    "\n",
    "\n",
    "def _bn_last(channel):\n",
    "    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,\n",
    "                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)\n",
    "\n",
    "\n",
    "def _fc(in_channel, out_channel):\n",
    "    weight_shape = (out_channel, in_channel)\n",
    "    weight = _weight_variable(weight_shape)\n",
    "    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)\n",
    "\n",
    "\n",
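    "class ResidualBlock(nn.Cell):\n",
    "    # Bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand; inner width is out_channel // expansion.\n",
    "    expansion = 4\n",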
    "\n",
    "    def __init__(self,\n",
    "                 in_channel,\n",
    "                 out_channel,\n",
    "                 stride=1):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "\n",
    "        channel = out_channel // self.expansion\n",
    "        self.conv1 = _conv1x1(in_channel, channel, stride=1)\n",
    "        self.bn1 = _bn(channel)\n",
    "\n",
    "        self.conv2 = _conv3x3(channel, channel, stride=stride)\n",
    "        self.bn2 = _bn(channel)\n",
    "\n",
    "        self.conv3 = _conv1x1(channel, out_channel, stride=1)\n",
    "        self.bn3 = _bn_last(out_channel)\n",
    "\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "        self.down_sample = False\n",
    "\n",
    "        if stride != 1 or in_channel != out_channel:\n",
    "            self.down_sample = True\n",
    "        self.down_sample_layer = None\n",
    "\n",
    "        if self.down_sample:\n",
    "            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),\n",
    "                                                        _bn(out_channel)])\n",
    "        self.add = P.TensorAdd()\n",
    "\n",
    "    def construct(self, x):\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.down_sample:\n",
    "            identity = self.down_sample_layer(identity)\n",
    "\n",
    "        out = self.add(out, identity)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "class ResNet(nn.Cell):\n",
    "\n",
    "    def __init__(self,\n",
    "                 block,\n",
    "                 layer_nums,\n",
    "                 in_channels,\n",
    "                 out_channels,\n",
    "                 strides,\n",
    "                 num_classes):\n",
    "        super(ResNet, self).__init__()\n",
    "\n",
    "        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:\n",
    "            raise ValueError(\"the length of layer_num, in_channels, out_channels list must be 4!\")\n",
    "\n",
    "        self.conv1 = _conv7x7(3, 64, stride=2)\n",
    "        self.bn1 = _bn(64)\n",
    "        self.relu = P.ReLU()\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode=\"same\")\n",
    "\n",
    "        self.layer1 = self._make_layer(block,\n",
    "                                       layer_nums[0],\n",
    "                                       in_channel=in_channels[0],\n",
    "                                       out_channel=out_channels[0],\n",
    "                                       stride=strides[0])\n",
    "        self.layer2 = self._make_layer(block,\n",
    "                                       layer_nums[1],\n",
    "                                       in_channel=in_channels[1],\n",
    "                                       out_channel=out_channels[1],\n",
    "                                       stride=strides[1])\n",
    "        self.layer3 = self._make_layer(block,\n",
    "                                       layer_nums[2],\n",
    "                                       in_channel=in_channels[2],\n",
    "                                       out_channel=out_channels[2],\n",
    "                                       stride=strides[2])\n",
    "        self.layer4 = self._make_layer(block,\n",
    "                                       layer_nums[3],\n",
    "                                       in_channel=in_channels[3],\n",
    "                                       out_channel=out_channels[3],\n",
    "                                       stride=strides[3])\n",
    "\n",
    "        self.mean = P.ReduceMean(keep_dims=True)\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.end_point = _fc(out_channels[3], num_classes)\n",
    "\n",
    "    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):\n",
    "        # The first block may change channels/stride; the remaining blocks keep the shape.\n",
    "        layers = []\n",
    "\n",
    "        resnet_block = block(in_channel, out_channel, stride=stride)\n",
    "        layers.append(resnet_block)\n",
    "\n",
    "        for _ in range(1, layer_num):\n",
    "            resnet_block = block(out_channel, out_channel, stride=1)\n",
    "            layers.append(resnet_block)\n",
    "\n",
    "        return nn.SequentialCell(layers)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        c1 = self.maxpool(x)\n",
    "\n",
    "        c2 = self.layer1(c1)\n",
    "        c3 = self.layer2(c2)\n",
    "        c4 = self.layer3(c3)\n",
    "        c5 = self.layer4(c4)\n",
    "\n",
    "        out = self.mean(c5, (2, 3))\n",
    "        out = self.flatten(out)\n",
    "        out = self.end_point(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "def resnet50(class_num=10):\n",
    "    # blocks per stage, stage input channels, stage output channels, stage strides\n",
    "    return ResNet(ResidualBlock,\n",
    "                  [3, 4, 6, 3],\n",
    "                  [64, 256, 512, 1024],\n",
    "                  [256, 512, 1024, 2048],\n",
    "                  [1, 2, 2, 2],\n",
    "                  class_num)\n",
    "\n",
    "def resnet101(class_num=1001):\n",
    "    return ResNet(ResidualBlock,\n",
    "                  [3, 4, 23, 3],\n",
    "                  [64, 256, 512, 1024],\n",
    "                  [256, 512, 1024, 2048],\n",
    "                  [1, 2, 2, 2],\n",
    "                  class_num)"
   ]
  },
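  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stage configuration passed to `ResNet` in `resnet50` determines both the feature-map sizes and the depth that gives the network its name. The plain-Python sketch below traces the shapes for the 224×224 inputs used here. It assumes `pad_mode='same'`, so each strided layer outputs `ceil(size / stride)`. It also counts the weighted layers: the 7×7 stem convolution, three convolutions in each of the 3 + 4 + 6 + 3 bottleneck blocks, and the final dense layer. The 1×1 shortcut convolutions are conventionally not counted.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Stage configuration taken from resnet50(): blocks, output channels, stride\n",
    "STAGES = [(3, 256, 1), (4, 512, 2), (6, 1024, 2), (3, 2048, 2)]\n",
    "\n",
    "def resnet50_stage_shapes(size):\n",
    "    # Stem: 7x7 conv with stride 2, then 3x3 max-pool with stride 2 ('same' padding)\n",
    "    size = math.ceil(math.ceil(size / 2) / 2)\n",
    "    shapes = []\n",
    "    for _, channels, stride in STAGES:\n",
    "        # Only the first block of a stage is strided\n",
    "        size = math.ceil(size / stride)\n",
    "        shapes.append((channels, size, size))\n",
    "    return shapes\n",
    "\n",
    "def resnet50_depth():\n",
    "    # stem conv + 3 convs per bottleneck block + final dense layer\n",
    "    return 1 + 3 * sum(blocks for blocks, _, _ in STAGES) + 1\n",
    "\n",
    "print(resnet50_stage_shapes(224))\n",
    "print(resnet50_depth())\n",
    "```\n",
    "\n",
    "After `layer4`, `ReduceMean` over axes (2, 3) collapses the 2048×7×7 feature map to a 2048-dimensional vector, which `end_point` maps to `class_num` scores."
   ]
  },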
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the `Time_per_Step` Callback to Measure Per-Step Training Time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
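    "`Time_per_Step` measures the time spent on each training step, which makes it easy to compare the performance of mixed-precision and single-precision training."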
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.train.callback import Callback\n",
    "import time\n",
    "\n",
    "class Time_per_Step(Callback):\n",
    "    # Print the wall-clock time of every training step in milliseconds.\n",
    "    def step_begin(self, run_context):\n",
    "        cb_params = run_context.original_args()\n",
    "        cb_params.init_time = time.time()\n",
    "\n",
    "    def step_end(self, run_context):\n",
    "        cb_params = run_context.original_args()\n",
    "        one_step_time = (time.time() - cb_params.init_time) * 1000\n",
    "        print(one_step_time, \"ms\")"
   ]
  },
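  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the training log below, the first step takes roughly 37.5 s while later steps take about 190 ms. The first step most likely includes graph compilation, so it should be left out when averaging. The framework-free sketch below shows one way to collect step times and average them while skipping warm-up steps. `StepTimer` is a hypothetical class that only mimics the `step_begin`/`step_end` hooks; it is not MindSpore's `Callback`. The injectable `clock` is there only to make the example testable.\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "import time\n",
    "\n",
    "class StepTimer:\n",
    "    # Hypothetical stand-in for a per-step timing callback (not a MindSpore API)\n",
    "    def __init__(self, clock=time.perf_counter):\n",
    "        self.clock = clock\n",
    "        self.durations_ms = []\n",
    "        self._start = None\n",
    "\n",
    "    def step_begin(self):\n",
    "        self._start = self.clock()\n",
    "\n",
    "    def step_end(self):\n",
    "        self.durations_ms.append((self.clock() - self._start) * 1000)\n",
    "\n",
    "    def mean_ms(self, skip_first=1):\n",
    "        # Drop warm-up steps (e.g. graph compilation) before averaging\n",
    "        steady = self.durations_ms[skip_first:]\n",
    "        return statistics.mean(steady) if steady else float('nan')\n",
    "\n",
    "# Usage with a fake clock that replays three steps: 37.5 s, 0.2 s, 0.2 s\n",
    "ticks = iter([0.0, 37.5, 37.5, 37.7, 37.7, 37.9])\n",
    "timer = StepTimer(clock=lambda: next(ticks))\n",
    "for _ in range(3):\n",
    "    timer.step_begin()\n",
    "    timer.step_end()\n",
    "print(round(timer.mean_ms(), 1))\n",
    "```\n",
    "\n",
    "In real use, the default `time.perf_counter` clock is a better fit than `time.time` for measuring short intervals, because it is monotonic and has a higher resolution."
   ]
  },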
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Training Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure and Run Mixed-Precision Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
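    "Because MindSpore has built-in automatic mixed precision, enabling it is simple. Passing `amp_level=\"O2\"` to `Model` is all it takes to set up mixed-precision training on GPU, and the model is then trained with mixed precision automatically at run time.\n",
    "\n",
    "Values of `amp_level`:\n",
    "\n",
    "`O0`: no change, i.e. single-precision training. This is the default.\n",
    "\n",
    "`O2`: casts the network's computation to float16 while keeping BatchNorm and the loss in float32. Recommended for GPU.\n",
    "\n",
    "`O3`: casts the network's computation to float16 with `keep_batchnorm_fp32=False`, so BatchNorm also runs in float16. Recommended for Ascend.\n",
    "\n",
    "Mixed-precision training runs once `amp_level=\"O2\"` is set in `Model`:"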
   ]
  },
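  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The principle described at the start of this tutorial keeps parameters and the loss in float32 and multiplies the loss by a scale factor before the backward pass. The NumPy sketch below illustrates why, without using MindSpore. First, a tiny gradient underflows to zero in float16 unless it is multiplied by a loss scale; 1024 is used here, the same value as `loss_scale` in the training cell. Second, a scaled gradient that overflows becomes `inf`, and the update is skipped. Third, a small update is lost entirely if the weight itself is stored in float16. `unscaled_fp16_grad` is an illustrative helper, not a MindSpore function.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def unscaled_fp16_grad(grad, loss_scale):\n",
    "    # Store the scaled gradient in float16, as an fp16 backward pass would\n",
    "    with np.errstate(over='ignore'):\n",
    "        scaled = np.float16(grad * loss_scale)\n",
    "    if not np.isfinite(scaled):\n",
    "        return None  # overflow: the optimizer should skip this update\n",
    "    # Unscale in float32 before the optimizer updates the fp32 master weights\n",
    "    return np.float32(scaled) / np.float32(loss_scale)\n",
    "\n",
    "print(unscaled_fp16_grad(1e-8, 1))      # underflows to 0.0 without scaling\n",
    "print(unscaled_fp16_grad(1e-8, 1024))   # survives with a loss scale of 1024\n",
    "print(unscaled_fp16_grad(100.0, 1024))  # None: 102400 exceeds the float16 max of 65504\n",
    "\n",
    "# Why master weights stay in float32: a 1e-4 update vanishes in float16\n",
    "print(np.float16(1.0) + np.float16(1e-4) == np.float16(1.0))  # True\n",
    "print(np.float32(1.0) + np.float32(1e-4) == np.float32(1.0))  # False\n",
    "```\n",
    "\n",
    "The sketch uses a fixed scale, similar in spirit to the `FixedLossScaleManager` imported in the training cell. Dynamic schemes instead lower the scale after an overflow and raise it again after a run of overflow-free steps."
   ]
  },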
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 1, loss is 2.3015203\n",
      "37518.837213516235 ms\n",
      "epoch: 1 step: 2, loss is 2.3068979\n",
      "197.05581665039062 ms\n",
      "epoch: 1 step: 3, loss is 2.3115108\n",
      "189.01705741882324 ms\n",
      "epoch: 1 step: 4, loss is 2.3279507\n",
      "188.4777545928955 ms\n",
      "epoch: 1 step: 5, loss is 2.2853572\n",
      "188.50111961364746 ms\n",
      "epoch: 1 step: 6, loss is 2.2706618\n",
      "188.63296508789062 ms\n",
      "epoch: 1 step: 7, loss is 2.325651\n",
      "213.5298252105713 ms\n",
      "epoch: 1 step: 8, loss is 2.3179858\n",
      "188.95459175109863 ms\n",
      "epoch: 1 step: 9, loss is 2.3060834\n",
      "193.02725791931152 ms\n",
      "epoch: 1 step: 10, loss is 2.39061\n",
      "192.83699989318848 ms\n",
      "\n",
      "\n",
      "......\n",
      "\n",
      "\n",
      "epoch: 8 step: 323, loss is 0.54335135\n",
      "190.31238555908203 ms\n",
      "epoch: 8 step: 324, loss is 0.2202819\n",
      "190.30189514160156 ms\n",
      "\n",
      "\n",
      "......\n",
      "\n",
      "\n",
      "epoch: 10 step: 1545, loss is 0.21533835\n",
      "192.63434410095215 ms\n",
      "epoch: 10 step: 1546, loss is 0.14042784\n",
      "192.5680637359619 ms\n",
      "epoch: 10 step: 1547, loss is 0.14810953\n",
      "192.64483451843262 ms\n",
      "epoch: 10 step: 1548, loss is 0.3791172\n",
      "192.7051544189453 ms\n",
      "epoch: 10 step: 1549, loss is 0.43446764\n",
      "192.60406494140625 ms\n",
      "epoch: 10 step: 1550, loss is 0.16453475\n",
      "192.5489902496338 ms\n",
      "epoch: 10 step: 1551, loss is 0.43192416\n",
      "192.45147705078125 ms\n",
      "epoch: 10 step: 1552, loss is 0.15318932\n",
      "192.69466400146484 ms\n",
      "epoch: 10 step: 1553, loss is 0.18142739\n",
      "192.4266815185547 ms\n",
      "epoch: 10 step: 1554, loss is 0.23418093\n",
      "191.3902759552002 ms\n",
      "epoch: 10 step: 1555, loss is 0.21376474\n",
      "190.4129981994629 ms\n",
      "epoch: 10 step: 1556, loss is 0.26256102\n",
      "190.1836395263672 ms\n",
      "epoch: 10 step: 1557, loss is 0.11623224\n",
      "190.07587432861328 ms\n",
      "epoch: 10 step: 1558, loss is 0.38422704\n",
      "190.2308464050293 ms\n",
      "epoch: 10 step: 1559, loss is 0.1297225\n",
      "190.05846977233887 ms\n",
      "epoch: 10 step: 1560, loss is 0.03785105\n",
      "189.89896774291992 ms\n",
      "epoch: 10 step: 1561, loss is 0.2947039\n",
      "190.0768280029297 ms\n",
      "epoch: 10 step: 1562, loss is 0.41113874\n",
      "190.03891944885254 ms\n",
      "Epoch time: 302610.106, per step time: 193.732\n"
     ]
    }
   ],
   "source": [
    "\"\"\"train ResNet-50\"\"\"\n",
    "import os\n",
    "import random\n",
    "import argparse\n",
    "from mindspore import context\n",
    "from mindspore.nn.optim.momentum import Momentum\n",
    "from mindspore.train.model import Model\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor\n",
    "from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits\n",
    "from mindspore.train.loss_scale_manager import FixedLossScaleManager\n",
    "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
    "import mindspore.dataset as de\n",
    "\n",
    "parser = argparse.ArgumentParser(description='Image classification')\n",
    "parser.add_argument('--net', type=str, default=\"resnet50\", help='Resnet Model, either resnet50 or resnet101')\n",
    "parser.add_argument('--dataset', type=str, default=\"cifar10\", help='Dataset, either cifar10 or imagenet2012')\n",
    "parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')\n",
    "parser.add_argument('--device_target', type=str, default='GPU', help='Device target')\n",
    "parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')\n",
    "args_opt = parser.parse_known_args()[0]\n",
    "\n",
    "random.seed(1)\n",
    "np.random.seed(1)\n",
    "de.config.set_seed(1)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "\n",
    "    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False, device_target=\"GPU\")\n",
    "    ckpt_save_dir = \"./resnet_ckpt\"\n",
    "    batch_size = 32\n",
    "    epoch_size = 10\n",
    "    dataset_path = \"./datasets/cifar10/train\"\n",
    "    test_path = \"./datasets/cifar10/test\"\n",
    "    \n",
    "    # create dataset\n",
    "    dataset = create_dataset(dataset_path=dataset_path, do_train=True, repeat_num=1,\n",
    "                                 batch_size=batch_size, target=\"GPU\")\n",
    "    step_size = dataset.get_dataset_size()\n",
    "    # define net\n",
    "    net = resnet50(class_num=10)\n",
    "    \n",
    "    # init weight\n",
    "    if args_opt.pre_trained:\n",
    "        param_dict = load_checkpoint(args_opt.pre_trained)\n",
    "        load_param_into_net(net, param_dict)\n",
    "    else:\n",
    "        for _, cell in net.cells_and_names():\n",
    "            if isinstance(cell, nn.Conv2d):\n",
    "                cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),\n",
    "                                                             cell.weight.data.shape,\n",
    "                                                             cell.weight.data.dtype).to_tensor())\n",
    "            if isinstance(cell, nn.Dense):\n",
    "                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),\n",
    "                                                             cell.weight.data.shape,\n",
    "                                                             cell.weight.data.dtype).to_tensor())\n",
    "    # init lr\n",
    "    warmup_epochs = 5\n",
    "    lr_init = 0.01\n",
    "    lr_end = 0.00001\n",
    "    lr_max = 0.1\n",
    "    lr = get_lr(lr_init=lr_init, lr_end=lr_end, lr_max=lr_max,\n",
    "                warmup_epochs=warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,\n",
    "                lr_decay_mode='poly')\n",
    "    lr = Tensor(lr)\n",
    "\n",
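    "    # define opt\n",
    "    # note: loss_scale and weight_decay are defined here but not passed to Momentum or Model below\n",
    "    loss_scale = 1024\n",
    "    momentum = 0.9\n",
    "    weight_decay = 1e-4\n",
    "\n",
    "    # define loss, model\n",
    "    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n",
    "    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)\n",
    "    # amp_level=\"O2\" turns on automatic mixed precision; set it to \"O0\" for single-precision training\n",
    "    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level=\"O2\")\n",
    "\n",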
    "    # define callbacks\n",
    "    steptime_cb = Time_per_Step()\n",
    "    time_cb = TimeMonitor(data_size=step_size)\n",
    "    loss_cb = LossMonitor()\n",
    "\n",
    "    cb = [time_cb, loss_cb, steptime_cb]\n",
    "    save_checkpoint = True\n",
    "    if save_checkpoint:\n",
    "        save_checkpoint_epochs = 5\n",
    "        keep_checkpoint_max = 10\n",
    "        config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_epochs * step_size,\n",
    "                                     keep_checkpoint_max=keep_checkpoint_max)\n",
    "        ckpt_cb = ModelCheckpoint(prefix=\"resnet\", directory=ckpt_save_dir, config=config_ck)\n",
    "        cb += [ckpt_cb]\n",
    "\n",
    "    # train model\n",
    "    model.train(epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Validate Model Accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the following code to evaluate the trained model's accuracy on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: {'acc': 0.8796073717948718}\n"
     ]
    }
   ],
   "source": [
    "# Eval model\n",
    "eval_dataset_path = \"./datasets/cifar10/test\"\n",
    "eval_data = create_dataset(eval_dataset_path, do_train=False)\n",
    "acc = model.eval(eval_data, dataset_sink_mode=True)\n",
    "print(\"Accuracy:\", acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mixed-Precision vs. Single-Precision Training on Different Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
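    "For brevity, only mixed-precision training of ResNet-50 is shown here. To train in single precision, set `amp_level=\"O0\"` in the `Model` at the main program entry, then compare the two runs after training. My results are summarized in the table below. (Training was done on an NVIDIA Tesla P40 GPU. Hardware has a large effect on training speed, so the numbers below are for reference only.)"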
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Network | Mixed precision | Time per step | Epochs | Accuracy |\n",
    "|:------ |:----- |:------ |:--- |:------ |\n",
    "| ResNet-50 | No | 232 ms | 10 | 0.881809 |\n",
    "| ResNet-50 | Yes | 192 ms | 10 | 0.879607 |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
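    "Over multiple test runs with ResNet-50 on CIFAR-10, mixed-precision training improved overall training efficiency by about 16% with little effect on final model accuracy (0.8796 vs. 0.8818). That is a gain worth taking when tuning overall performance."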
   ]
  },
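  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The exact percentage depends on two choices: which per-step time is used, and whether the gain is expressed as time saved per step or as extra throughput. The arithmetic below uses the table values and the 193.732 ms average from the `Epoch time` line of the training log above. It is a worked check of numbers already reported, not a new measurement.\n",
    "\n",
    "```python\n",
    "fp32_ms = 232.0       # single-precision time per step (table)\n",
    "amp_table_ms = 192.0  # mixed-precision time per step (table)\n",
    "amp_log_ms = 193.732  # mixed-precision average per step from the training log\n",
    "\n",
    "def time_saved(baseline_ms, new_ms):\n",
    "    # Fraction of per-step time removed\n",
    "    return (baseline_ms - new_ms) / baseline_ms\n",
    "\n",
    "def throughput_gain(baseline_ms, new_ms):\n",
    "    # Extra steps per second relative to the baseline\n",
    "    return baseline_ms / new_ms - 1\n",
    "\n",
    "print(round(100 * time_saved(fp32_ms, amp_table_ms), 1))       # 17.2\n",
    "print(round(100 * throughput_gain(fp32_ms, amp_table_ms), 1))  # 20.8\n",
    "print(round(100 * time_saved(fp32_ms, amp_log_ms), 1))         # 16.5\n",
    "print(round(100 * (0.881809 - 0.879607), 2))                   # 0.22 accuracy points\n",
    "```\n",
    "\n",
    "The 16% figure is therefore consistent with the time saved per step measured against the log's 193.732 ms average. Expressed as throughput, the table values correspond to roughly 21% more steps per second."
   ]
  },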
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to look at single-step training or configure mixed precision manually, see the official tutorial: <https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html>."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
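    "In this walkthrough, we trained ResNet-50 with mixed precision and compared it with single-precision training. Along the way, we looked at how mixed-precision training works and how much it speeds up model training."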
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}